# FScanpy-package/FScanpy/predictor.py
# Source-view metadata: 427 lines, 19 KiB, Python.
# Revision: 2025-03-18 11:21:54 +08:00
import os
import pickle
import numpy as np
import pandas as pd
from tensorflow.keras.models import load_model
from .features.sequence import SequenceFeatureExtractor
from .features.cnn_input import CNNInputProcessor
from .utils import extract_window_sequences
import matplotlib.pyplot as plt
import joblib
class PRFPredictor:
    """Predict PRF (frameshift) sites in DNA sequences with two pretrained models.

    A gradient-boosting ("GB") model scores a short window around each
    position, and a second ("CNN") model scores a long context window. When
    the GB probability falls below a threshold the CNN is skipped and the
    position is scored 0.0. Otherwise the final 'Voting_Probability' is a
    fixed 0.4 / 0.6 weighted average of the GB and CNN probabilities.
    """

    # Weights of the GB and CNN probabilities in 'Voting_Probability'.
    _GB_WEIGHT = 0.4
    _CNN_WEIGHT = 0.6

    # Column orders of the DataFrames returned by predict_full / predict_region.
    # Passing them explicitly also gives empty results a stable schema.
    _PROB_COLUMNS = ('GB_Probability', 'CNN_Probability', 'Voting_Probability')
    _FULL_COLUMNS = _PROB_COLUMNS + ('Position', 'Codon', '33bp', '399bp')
    _REGION_COLUMNS = _PROB_COLUMNS + ('33bp', '399bp')

    def __init__(self, model_dir=None):
        """Load the pretrained models and set up the feature processors.

        Args:
            model_dir: Directory containing 'GradientBoosting_all.pkl' and
                'BiLSTM-CNN_all.pkl'. Defaults to the 'pretrained' directory
                that ships next to this module inside the FScanpy package.

        Raises:
            FileNotFoundError: A model file is missing.
            Exception: Any other failure while loading models or building
                the feature processors.
        """
        if model_dir is None:
            # Resolve the bundled directory relative to this file instead of
            # pkg_resources, which is deprecated and lives in setuptools (not
            # installed by default in recent virtualenvs).
            model_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'pretrained')
        try:
            self.gb_model = self._load_pickle(os.path.join(model_dir, 'GradientBoosting_all.pkl'))
            self.cnn_model = self._load_pickle(os.path.join(model_dir, 'BiLSTM-CNN_all.pkl'))
            # Window lengths must match the ones the models were trained with.
            self.gb_seq_length = 33    # input length of the GB model
            self.cnn_seq_length = 399  # input length of the BiLSTM-CNN model
            self.feature_extractor = SequenceFeatureExtractor(seq_length=self.gb_seq_length)
            self.cnn_processor = CNNInputProcessor(max_length=self.cnn_seq_length)
        except FileNotFoundError as e:
            raise FileNotFoundError(f"无法找到模型文件: {str(e)}") from e
        except Exception as e:
            raise Exception(f"加载模型出错: {str(e)}") from e

    def _load_pickle(self, path):
        """Deserialize one model file with joblib.

        joblib.load unpickles its input, which can execute arbitrary code, so
        ``model_dir`` must only ever point at trusted model files.
        """
        return joblib.load(path)

    def predict_single_position(self, fs_period, full_seq, gb_threshold=0.1):
        """Score one candidate frameshift position.

        Args:
            fs_period: Short window scored by the GB model. It is trimmed with
                ``feature_extractor.trim_sequence`` when longer than
                ``gb_seq_length``.
            full_seq: Long context window scored by the CNN model.
            gb_threshold: GB probability below which the CNN is skipped and the
                position gets CNN/voting probabilities of 0.0 (default 0.1).

        Returns:
            dict with 'GB_Probability', 'CNN_Probability' and 'Voting_Probability'.

        Raises:
            Exception: Wraps any error outside the individual model calls.
                Model errors themselves are printed and scored as 0.0.
        """
        try:
            if len(fs_period) > self.gb_seq_length:
                fs_period = self.feature_extractor.trim_sequence(fs_period, self.gb_seq_length)
            return self._predict_probabilities(fs_period, full_seq, gb_threshold)
        except Exception as e:
            raise Exception(f"预测过程出错: {str(e)}") from e

    def predict_full(self, sequence, window_size=3, gb_threshold=0.1, plot=False):
        """Slide over ``sequence`` and score every ``window_size``-th position.

        Args:
            sequence: Input DNA sequence. It is converted with ``str()`` and
                upper-cased.
            window_size: Step between scored positions (default 3).
            gb_threshold: GB probability threshold (default 0.1).
            plot: Also build a matplotlib figure of the results (default False).

        Returns:
            DataFrame with columns GB_Probability, CNN_Probability,
            Voting_Probability, Position, Codon, 33bp and 399bp when
            ``plot`` is False. When ``plot`` is True, a
            ``(DataFrame, matplotlib.figure.Figure)`` tuple.

        Raises:
            ValueError: ``window_size < 1`` or ``gb_threshold < 0``.
            Exception: Wraps any error raised while predicting or plotting.
        """
        if window_size < 1:
            raise ValueError("窗口大小必须大于等于1")
        if gb_threshold < 0:
            raise ValueError("GB阈值必须大于等于0")
        try:
            sequence = str(sequence).upper()
            results = []
            # The last position is len - 3, so every reported codon is complete.
            for pos in range(0, len(sequence) - 2, window_size):
                fs_period, full_seq = extract_window_sequences(sequence, pos)
                if fs_period is None or full_seq is None:
                    continue
                pred = self.predict_single_position(fs_period, full_seq, gb_threshold)
                pred.update({
                    'Position': pos,
                    'Codon': sequence[pos:pos + 3],
                    '33bp': fs_period,
                    '399bp': full_seq,
                })
                results.append(pred)
            results_df = pd.DataFrame(results, columns=list(self._FULL_COLUMNS))
            if plot:
                return results_df, self._plot_predictions(results_df)
            return results_df
        except Exception as e:
            raise Exception(f"序列预测过程出错: {str(e)}") from e

    def predict_region(self, seq, gb_threshold=0.1):
        """Score one or more pre-extracted long (399bp) windows.

        The GB model sees the centred ``gb_seq_length`` slice of each window.
        The CNN model sees the whole window.

        Args:
            seq: A sequence string, an iterable of strings, a ``pd.Series``,
                or a ``pd.DataFrame`` that has a '399bp' column or exactly one
                column.
            gb_threshold: GB probability below which the CNN is skipped
                (default 0.1).

        Returns:
            DataFrame with columns GB_Probability, CNN_Probability,
            Voting_Probability, 33bp and 399bp, one row per input sequence.
            Sequences are upper-cased, as in ``predict_full``.

        Raises:
            Exception: Wraps input-conversion errors, e.g. a DataFrame with
                several columns and no '399bp' column.
        """
        try:
            sequences = self._as_sequence_list(seq)
            results = []
            for i, raw_seq in enumerate(sequences):
                # Normalize like predict_full so the GB slice and the CNN input
                # are built from the same upper-cased text.
                seq399 = str(raw_seq).upper()
                seq33 = self._extract_center_sequence(seq399, target_length=self.gb_seq_length)
                try:
                    probs = self._predict_probabilities(
                        seq33, seq399, gb_threshold, label=f"序列 {i+1} ")
                except Exception as e:
                    # Best effort: a failing sequence becomes an all-zero row.
                    print(f"处理第 {i+1} 个序列时出错: {str(e)}")
                    probs = dict.fromkeys(self._PROB_COLUMNS, 0.0)
                results.append({**probs, '33bp': seq33, '399bp': seq399})
            return pd.DataFrame(results, columns=list(self._REGION_COLUMNS))
        except Exception as e:
            raise Exception(f"区域预测过程出错: {str(e)}") from e

    def _predict_probabilities(self, seq33, seq399, gb_threshold, label=''):
        """Run the GB model and, if it clears ``gb_threshold``, the CNN model.

        A failure in either model is printed and scored as 0.0, so one bad
        window does not abort a batch. ``label`` is inserted into those
        messages to identify the sequence (e.g. "序列 3 ").

        Returns:
            dict with 'GB_Probability', 'CNN_Probability' and 'Voting_Probability'.
        """
        warn_prefix = f"{label}的" if label else ""
        try:
            gb_prob = self._gb_probability(seq33, warn_prefix)
        except Exception as e:
            print(f"GB模型预测{label}时出错: {str(e)}")
            gb_prob = 0.0

        # Below the GB threshold the CNN is skipped and the window is negative.
        if gb_prob < gb_threshold:
            return {'GB_Probability': gb_prob, 'CNN_Probability': 0.0, 'Voting_Probability': 0.0}

        try:
            cnn_prob = self._cnn_probability(seq399)
        except Exception as e:
            print(f"CNN模型预测{label}时出错: {str(e)}")
            cnn_prob = 0.0

        return {
            'GB_Probability': gb_prob,
            'CNN_Probability': cnn_prob,
            'Voting_Probability': self._GB_WEIGHT * gb_prob + self._CNN_WEIGHT * cnn_prob,
        }

    def _gb_probability(self, seq33, warn_prefix=''):
        """Return the GB model's positive-class probability for ``seq33``."""
        features = np.asarray(self.feature_extractor.extract_features(seq33))
        if features.ndim > 1:
            print(f"警告: {warn_prefix}特征是{features.ndim}维数组,进行扁平化处理")
        # predict_proba expects a (1, n_features) row.
        return float(self.gb_model.predict_proba(features.reshape(1, -1))[0][1])

    def _cnn_probability(self, seq399):
        """Return the second model's positive-class probability for ``seq399``."""
        if hasattr(self.cnn_model, 'predict_proba'):
            # A scikit-learn style model is fed the same hand-crafted features
            # as the GB model, computed on the long window.
            features = np.asarray(self.feature_extractor.extract_features(seq399))
            return float(self.cnn_model.predict_proba(features.reshape(1, -1))[0][1])

        cnn_input = self.cnn_processor.prepare_sequence(seq399)
        try:
            # verbose=0 stops Keras printing a progress bar on every window.
            raw_pred = self.cnn_model.predict(cnn_input, verbose=0)
        except TypeError:
            # predict() does not accept `verbose`, or rejected the input.
            try:
                raw_pred = self.cnn_model.predict(cnn_input)
            except Exception:
                # Last resort: some models expect a flat (1, n) input.
                raw_pred = self.cnn_model.predict(np.asarray(cnn_input).reshape(1, -1))
        return self._first_positive_probability(raw_pred)

    @staticmethod
    def _first_positive_probability(pred):
        """Extract the first sample's positive-class probability from model output.

        Handles the following shapes:
        - a list of outputs: the first output is used;
        - ``(n, k>1)`` class-probability arrays: column 1 is used;
        - ``(n, 1)`` single-unit outputs and flat ``(n,)`` arrays: the first
          element is used.
        """
        if isinstance(pred, list):
            pred = pred[0]
        arr = np.asarray(pred, dtype=float)
        if arr.ndim > 1 and arr.shape[1] > 1:
            return float(arr[0, 1])
        # Flatten first: indexing a numpy scalar twice raises IndexError.
        return float(arr.reshape(-1)[0])

    @staticmethod
    def _as_sequence_list(seq):
        """Normalize ``predict_region`` input into an iterable of sequences."""
        if isinstance(seq, pd.DataFrame):
            # DataFrame has no tolist(); pick the sequence column explicitly.
            if '399bp' in seq.columns:
                seq = seq['399bp']
            elif seq.shape[1] == 1:
                seq = seq.iloc[:, 0]
            else:
                raise ValueError("DataFrame输入必须包含'399bp'列或仅包含一列序列")
        if isinstance(seq, pd.Series):
            return seq.tolist()
        if isinstance(seq, str):
            return [seq]
        return seq

    @staticmethod
    def _plot_predictions(results_df):
        """Build the two-panel figure returned by ``predict_full(plot=True)``.

        The top panel shows the GB, CNN and voting probabilities per position.
        The bottom panel shows the voting probability as a one-row heat strip;
        the strip is omitted when there are no rows.
        """
        # gridspec_kw works on every matplotlib version; the height_ratios
        # keyword of subplots() only exists from matplotlib 3.6.
        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 10),
                                       gridspec_kw={'height_ratios': [2, 1]})
        positions = results_df['Position']

        ax1.plot(positions, results_df['GB_Probability'],
                 label='GB模型', alpha=0.7, linewidth=1.5)
        ax1.plot(positions, results_df['CNN_Probability'],
                 label='CNN模型', alpha=0.7, linewidth=1.5)
        ax1.plot(positions, results_df['Voting_Probability'],
                 label='投票模型', linewidth=2, color='red')
        ax1.set_xlabel('序列位置')
        ax1.set_ylabel('移码概率')
        ax1.set_title('移码预测概率')
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        if not results_df.empty:
            heat = results_df['Voting_Probability'].to_numpy(dtype=float).reshape(1, -1)
            im = ax2.imshow(heat, aspect='auto', cmap='YlOrRd',
                            extent=[positions.min(), positions.max(), 0, 1])
            cbar = fig.colorbar(im, ax=ax2)
            cbar.set_label('移码概率')
        ax2.set_xlabel('序列位置')
        ax2.set_title('移码概率热图')
        ax2.set_yticks([])

        fig.tight_layout()
        return fig

    def _extract_center_sequence(self, sequence, target_length=33):
        """Return the ``target_length`` characters centred in ``sequence``.

        The input is converted with ``str()`` and upper-cased. Sequences no
        longer than ``target_length`` are returned whole. For example, a
        399-character input with the default length yields
        ``sequence[183:216]``.
        """
        sequence = str(sequence).upper()
        if len(sequence) <= target_length:
            return sequence
        # Because len(sequence) > target_length, start >= 0 and
        # start + target_length <= len(sequence), so no clamping is needed.
        start = len(sequence) // 2 - target_length // 2
        return sequence[start:start + target_length]