FScanpy-package/FScanpy/predictor.py

import os
import pickle
import numpy as np
import pandas as pd
from tensorflow.keras.models import load_model
from .features.sequence import SequenceFeatureExtractor
from .features.cnn_input import CNNInputProcessor
from .utils import extract_window_sequences
import matplotlib.pyplot as plt
import joblib


class PRFPredictor:
    def __init__(self, model_dir=None):
        """
        Initialize the PRF predictor.

        Args:
            model_dir: path to the model directory (optional)
        """
        if model_dir is None:
            from pkg_resources import resource_filename
            model_dir = resource_filename('FScanpy', 'pretrained')
        try:
            # Load the pretrained models
            self.gb_model = self._load_pickle(os.path.join(model_dir, 'GradientBoosting_all.pkl'))
            self.cnn_model = self._load_pickle(os.path.join(model_dir, 'BiLSTM-CNN_all.pkl'))
            self.voting_model = self._load_pickle(os.path.join(model_dir, 'Voting_all.pkl'))

            # Use the same sequence lengths as during training
            self.gb_seq_length = 33    # sequence length used by the gradient-boosting model
            self.cnn_seq_length = 399  # sequence length used by the BiLSTM-CNN model

            # Initialize the feature extractor and the CNN input processor
            self.feature_extractor = SequenceFeatureExtractor(seq_length=self.gb_seq_length)
            self.cnn_processor = CNNInputProcessor(max_length=self.cnn_seq_length)
        except FileNotFoundError as e:
            raise FileNotFoundError(f"Model file not found: {str(e)}")
        except Exception as e:
            raise Exception(f"Error while loading models: {str(e)}")

    def _load_pickle(self, path):
        return joblib.load(path)
    def predict_single_position(self, fs_period, full_seq, gb_threshold=0.1):
        '''
        Predict the PRF state at a single position.

        Args:
            fs_period: 33 bp sequence (trimmed to gb_seq_length if longer)
            full_seq: full-length sequence (processed to cnn_seq_length)
            gb_threshold: probability threshold for the GB model (default 0.1)

        Returns:
            dict: dictionary of prediction probabilities
        '''
        try:
            # Trim the sequence if it is longer than the GB input length
            if len(fs_period) > self.gb_seq_length:
                fs_period = self.feature_extractor.trim_sequence(fs_period, self.gb_seq_length)

            # GB model prediction - make sure the input is a 2D array
            try:
                gb_features = self.feature_extractor.extract_features(fs_period)
                # Check the feature structure and make sure it is a 1D array
                if isinstance(gb_features, np.ndarray):
                    # Flatten multi-dimensional arrays
                    if gb_features.ndim > 1:
                        print(f"Warning: features are a {gb_features.ndim}-dimensional array, flattening")
                        gb_features = gb_features.flatten()
                # Explicitly convert the features to a 2D array of shape (1, n_features)
                gb_features_2d = np.array([gb_features])
                # Check the dimensionality again
                if gb_features_2d.ndim != 2:
                    raise ValueError(f"Features are still {gb_features_2d.ndim}-dimensional after processing, a 2D array is required")
                gb_prob = self.gb_model.predict_proba(gb_features_2d)[0][1]
            except Exception as e:
                print(f"Error during GB model prediction: {str(e)}")
                # Fall back to a probability of 0 on error
                gb_prob = 0.0

            # Skip the CNN model if the GB probability is below the threshold
            if gb_prob < gb_threshold:
                return {
                    'GB_Probability': gb_prob,
                    'CNN_Probability': 0.0,
                    'Voting_Probability': 0.0
                }

            # CNN model prediction
            try:
                # First determine the type of the CNN model
                is_sklearn_model = False
                if hasattr(self.cnn_model, 'predict_proba'):
                    # This is probably a scikit-learn model
                    is_sklearn_model = True
                if is_sklearn_model:
                    # For a sklearn model (e.g. HistGradientBoostingClassifier) use the same
                    # feature extraction as the GB model, but on the 399 bp sequence
                    cnn_features = self.feature_extractor.extract_features(full_seq)
                    if isinstance(cnn_features, np.ndarray) and cnn_features.ndim > 1:
                        cnn_features = cnn_features.flatten()
                    # Convert to a 2D array
                    cnn_features_2d = np.array([cnn_features])
                    cnn_pred = self.cnn_model.predict_proba(cnn_features_2d)
                    cnn_prob = cnn_pred[0][1]
                else:
                    # Assume a deep-learning model that expects 3D input
                    cnn_input = self.cnn_processor.prepare_sequence(full_seq)
                    # Try different prediction call signatures
                    try:
                        # First try without extra arguments
                        cnn_pred = self.cnn_model.predict(cnn_input)
                    except TypeError:
                        try:
                            # Then try with the verbose argument
                            cnn_pred = self.cnn_model.predict(cnn_input, verbose=0)
                        except Exception:
                            # Finally try reshaping the input to 2D
                            reshaped_input = cnn_input.reshape(1, -1)
                            cnn_pred = self.cnn_model.predict(reshaped_input)
                # Post-process the prediction
                if isinstance(cnn_pred, list):
                    cnn_pred = cnn_pred[0]
                # Extract the probability value
                if hasattr(cnn_pred, 'shape') and len(cnn_pred.shape) > 1 and cnn_pred.shape[1] > 1:
                    cnn_prob = cnn_pred[0][1]
                else:
                    cnn_prob = cnn_pred[0][0] if hasattr(cnn_pred[0], '__getitem__') else cnn_pred[0]
            except Exception as e:
                print(f"Error during CNN model prediction: {str(e)}")
                # Fall back to a probability of 0 on error
                cnn_prob = 0.0

            # Voting model prediction
            try:
                # Make sure the voting model input is a 2D array of shape (1, n_features)
                voting_input = np.array([[gb_prob, cnn_prob]])
                voting_prob = self.voting_model.predict_proba(voting_input)[0][1]
            except Exception as e:
                print(f"Error during voting model prediction: {str(e)}")
                # Fall back to the average of the two models on error
                voting_prob = (gb_prob + cnn_prob) / 2

            return {
                'GB_Probability': gb_prob,
                'CNN_Probability': cnn_prob,
                'Voting_Probability': voting_prob
            }
        except Exception as e:
            raise Exception(f"Error during prediction: {str(e)}")
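
    # Illustrative usage sketch for predict_single_position (kept as a comment so the
    # module stays import-safe); the literal sequences below are placeholders, not
    # real frameshift sites:
    #
    #   predictor = PRFPredictor()
    #   probs = predictor.predict_single_position(fs_period='ATG' * 11,   # 33 bp window
    #                                             full_seq='ATG' * 133)   # 399 bp context
    #   print(probs['Voting_Probability'])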
    def predict_full(self, sequence, window_size=3, gb_threshold=0.1, plot=False):
        """
        Predict PRF sites along a full sequence.

        Args:
            sequence: input DNA sequence
            window_size: sliding-window step size (default 3)
            gb_threshold: probability threshold for the GB model (default 0.1)
            plot: whether to plot the prediction results (default False)

        Returns:
            if plot=False:
                pd.DataFrame: DataFrame with the prediction results
            if plot=True:
                tuple: (pd.DataFrame, matplotlib.figure.Figure)
        """
        if window_size < 1:
            raise ValueError("window_size must be at least 1")
        if gb_threshold < 0:
            raise ValueError("gb_threshold must be at least 0")

        results = []
        try:
            # Make sure the sequence is an upper-case string
            sequence = str(sequence).upper()

            # Sliding-window prediction
            for pos in range(0, len(sequence) - 2, window_size):
                # Extract the window sequences - same window sizes as during training
                fs_period, full_seq = extract_window_sequences(sequence, pos)
                if fs_period is None or full_seq is None:
                    continue
                # Predict and record the result
                pred = self.predict_single_position(fs_period, full_seq, gb_threshold)
                pred.update({
                    'Position': pos,
                    'Codon': sequence[pos:pos+3],
                    '33bp': fs_period,
                    '399bp': full_seq
                })
                results.append(pred)

            # Build the results DataFrame
            results_df = pd.DataFrame(results)

            # Plot if requested
            if plot:
                # Create the figure
                fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 10), height_ratios=[2, 1])

                # Line plot of the per-model probabilities
                ax1.plot(results_df['Position'], results_df['GB_Probability'],
                         label='GB model', alpha=0.7, linewidth=1.5)
                ax1.plot(results_df['Position'], results_df['CNN_Probability'],
                         label='CNN model', alpha=0.7, linewidth=1.5)
                ax1.plot(results_df['Position'], results_df['Voting_Probability'],
                         label='Voting model', linewidth=2, color='red')
                ax1.set_xlabel('Sequence position')
                ax1.set_ylabel('Frameshift probability')
                ax1.set_title('Frameshift prediction probabilities')
                ax1.legend()
                ax1.grid(True, alpha=0.3)

                # Prepare the heatmap data
                positions = results_df['Position'].values
                probabilities = results_df['Voting_Probability'].values

                # Build the heatmap matrix
                heatmap_matrix = np.zeros((1, len(positions)))
                heatmap_matrix[0, :] = probabilities

                # Draw the heatmap
                im = ax2.imshow(heatmap_matrix, aspect='auto', cmap='YlOrRd',
                                extent=[min(positions), max(positions), 0, 1])

                # Add a colorbar
                cbar = plt.colorbar(im, ax=ax2)
                cbar.set_label('Frameshift probability')

                # Axis labels for the heatmap
                ax2.set_xlabel('Sequence position')
                ax2.set_title('Frameshift probability heatmap')
                ax2.set_yticks([])

                # Adjust the layout
                plt.tight_layout()
                return results_df, fig
            return results_df
        except Exception as e:
            raise Exception(f"Error during full-sequence prediction: {str(e)}")
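
    # Illustrative usage sketch for predict_full (kept as a comment so the module
    # stays import-safe); my_sequence is a placeholder DNA string, and with
    # plot=True the figure is returned alongside the DataFrame:
    #
    #   df, fig = predictor.predict_full(my_sequence, window_size=3,
    #                                    gb_threshold=0.1, plot=True)
    #   fig.savefig('prf_scan.png')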
    def predict_region(self, seq, gb_threshold=0.1):
        '''
        Predict PRF probabilities for region sequences.

        Args:
            seq: a 399 bp sequence, or a DataFrame/Series containing 399 bp sequences
            gb_threshold: probability threshold for the GB model (default 0.1)

        Returns:
            DataFrame: DataFrame with the prediction probabilities for all sequences
        '''
        try:
            # Convert DataFrame or Series input to a list
            if isinstance(seq, (pd.DataFrame, pd.Series)):
                seq = seq.tolist()
            # Convert a single string to a list
            if isinstance(seq, str):
                seq = [seq]

            results = []
            for i, seq399 in enumerate(seq):
                try:
                    # Take the central 33 bp of the 399 bp sequence (used by the GB model)
                    seq33 = self._extract_center_sequence(seq399, target_length=self.gb_seq_length)

                    # GB model prediction - make sure the input is a 2D array
                    try:
                        gb_features = self.feature_extractor.extract_features(seq33)
                        # Check the feature structure and make sure it is a 1D array
                        if isinstance(gb_features, np.ndarray):
                            # Flatten multi-dimensional arrays
                            if gb_features.ndim > 1:
                                print(f"Warning: features of sequence {i+1} are a {gb_features.ndim}-dimensional array, flattening")
                                gb_features = gb_features.flatten()
                        # Explicitly convert the features to a 2D array of shape (1, n_features)
                        gb_features_2d = np.array([gb_features])
                        # Check the dimensionality again
                        if gb_features_2d.ndim != 2:
                            raise ValueError(f"Features are still {gb_features_2d.ndim}-dimensional after processing, a 2D array is required")
                        gb_prob = self.gb_model.predict_proba(gb_features_2d)[0][1]
                    except Exception as e:
                        print(f"Error during GB model prediction for sequence {i+1}: {str(e)}")
                        # Fall back to a probability of 0 on error
                        gb_prob = 0.0

                    # If the GB probability is below the threshold, record a low-probability result
                    if gb_prob < gb_threshold:
                        results.append({
                            'GB_Probability': gb_prob,
                            'CNN_Probability': 0.0,
                            'Voting_Probability': 0.0,
                            '33bp': seq33,
                            '399bp': seq399
                        })
                        continue

                    # CNN model prediction
                    try:
                        # First determine the type of the CNN model
                        is_sklearn_model = False
                        if hasattr(self.cnn_model, 'predict_proba'):
                            # This is probably a scikit-learn model
                            is_sklearn_model = True
                        if is_sklearn_model:
                            # For a sklearn model (e.g. HistGradientBoostingClassifier) use the same
                            # feature extraction as the GB model, but on the 399 bp sequence
                            cnn_features = self.feature_extractor.extract_features(seq399)
                            if isinstance(cnn_features, np.ndarray) and cnn_features.ndim > 1:
                                cnn_features = cnn_features.flatten()
                            # Convert to a 2D array
                            cnn_features_2d = np.array([cnn_features])
                            cnn_pred = self.cnn_model.predict_proba(cnn_features_2d)
                            cnn_prob = cnn_pred[0][1]
                        else:
                            # Assume a deep-learning model that expects 3D input
                            cnn_input = self.cnn_processor.prepare_sequence(seq399)
                            # Try different prediction call signatures
                            try:
                                # First try without extra arguments
                                cnn_pred = self.cnn_model.predict(cnn_input)
                            except TypeError:
                                try:
                                    # Then try with the verbose argument
                                    cnn_pred = self.cnn_model.predict(cnn_input, verbose=0)
                                except Exception:
                                    # Finally try reshaping the input to 2D
                                    reshaped_input = cnn_input.reshape(1, -1)
                                    cnn_pred = self.cnn_model.predict(reshaped_input)
                        # Post-process the prediction
                        if isinstance(cnn_pred, list):
                            cnn_pred = cnn_pred[0]
                        # Extract the probability value
                        if hasattr(cnn_pred, 'shape') and len(cnn_pred.shape) > 1 and cnn_pred.shape[1] > 1:
                            cnn_prob = cnn_pred[0][1]
                        else:
                            cnn_prob = cnn_pred[0][0] if hasattr(cnn_pred[0], '__getitem__') else cnn_pred[0]
                    except Exception as e:
                        print(f"Error during CNN model prediction for sequence {i+1}: {str(e)}")
                        # Fall back to a probability of 0 on error
                        cnn_prob = 0.0

                    # Voting model prediction
                    try:
                        # Make sure the voting model input is a 2D array of shape (1, n_features)
                        voting_input = np.array([[gb_prob, cnn_prob]])
                        voting_prob = self.voting_model.predict_proba(voting_input)[0][1]
                    except Exception as e:
                        print(f"Error during voting model prediction for sequence {i+1}: {str(e)}")
                        # Fall back to the average of the two models on error
                        voting_prob = (gb_prob + cnn_prob) / 2

                    results.append({
                        'GB_Probability': gb_prob,
                        'CNN_Probability': cnn_prob,
                        'Voting_Probability': voting_prob,
                        '33bp': seq33,
                        '399bp': seq399
                    })
                except Exception as e:
                    print(f"Error while processing sequence {i+1}: {str(e)}")
                    results.append({
                        'GB_Probability': 0.0,
                        'CNN_Probability': 0.0,
                        'Voting_Probability': 0.0,
                        '33bp': self._extract_center_sequence(seq399, target_length=self.gb_seq_length) if len(seq399) >= self.gb_seq_length else seq399,
                        '399bp': seq399
                    })
            return pd.DataFrame(results)
        except Exception as e:
            raise Exception(f"Error during region prediction: {str(e)}")
    def _extract_center_sequence(self, sequence, target_length=33):
        """Extract a subsequence of the given length from the centre of a sequence."""
        # Make sure the sequence is an upper-case string
        sequence = str(sequence).upper()
        # If the sequence is no longer than the target length, return it unchanged
        if len(sequence) <= target_length:
            return sequence
        # Locate the centre
        center = len(sequence) // 2
        half_target = target_length // 2
        # Extract the central window
        start = center - half_target
        end = start + target_length
        # Boundary checks
        if start < 0:
            start = 0
            end = target_length
        elif end > len(sequence):
            end = len(sequence)
            start = end - target_length
        return sequence[start:end]
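

# A minimal usage sketch (an assumption about typical use, not part of the library
# API): instantiate the predictor with the bundled pretrained models and score a
# pre-extracted 399 bp region with predict_region. The toy sequence below is a
# placeholder, not a known frameshift site.
if __name__ == "__main__":
    predictor = PRFPredictor()
    region = 'ATGGCT' * 66 + 'TAA'  # 399 bp placeholder region
    region_df = predictor.predict_region(region, gb_threshold=0.1)
    print(region_df[['GB_Probability', 'CNN_Probability', 'Voting_Probability']])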