434 lines
20 KiB
Python
434 lines
20 KiB
Python
import os
|
||
import pickle
|
||
import numpy as np
|
||
import pandas as pd
|
||
from tensorflow.keras.models import load_model
|
||
from .features.sequence import SequenceFeatureExtractor
|
||
from .features.cnn_input import CNNInputProcessor
|
||
from .utils import extract_window_sequences
|
||
import matplotlib.pyplot as plt
|
||
import joblib
|
||
|
||
|
||
class PRFPredictor:
    """Score candidate PRF (frameshift) sites with a three-model cascade.

    A gradient-boosting model screens a short window (``gb_seq_length`` bp).
    Only windows whose score reaches the caller's threshold are passed to the
    second model (stored as ``cnn_model``) on the long window
    (``cnn_seq_length`` bp); a voting model then combines both scores.
    """

    def __init__(self, model_dir=None):
        """Load the three pretrained models and build the input processors.

        Args:
            model_dir: Directory holding ``GradientBoosting_all.pkl``,
                ``BiLSTM-CNN_all.pkl`` and ``Voting_all.pkl``. When ``None``,
                the ``pretrained`` directory of the installed ``FScanpy``
                package is used.

        Raises:
            FileNotFoundError: If a model file cannot be found.
            Exception: For any other failure while loading or initialising.
        """
        if model_dir is None:
            model_dir = self._default_model_dir()

        try:
            self.gb_model = self._load_pickle(
                os.path.join(model_dir, 'GradientBoosting_all.pkl'))
            self.cnn_model = self._load_pickle(
                os.path.join(model_dir, 'BiLSTM-CNN_all.pkl'))
            self.voting_model = self._load_pickle(
                os.path.join(model_dir, 'Voting_all.pkl'))

            # Window lengths must match the ones used at training time.
            self.gb_seq_length = 33    # short window for the gradient-boosting model
            self.cnn_seq_length = 399  # long window for the BiLSTM-CNN model

            self.feature_extractor = SequenceFeatureExtractor(
                seq_length=self.gb_seq_length)
            self.cnn_processor = CNNInputProcessor(
                max_length=self.cnn_seq_length)
        except FileNotFoundError as e:
            raise FileNotFoundError(f"无法找到模型文件: {str(e)}") from e
        except Exception as e:
            raise Exception(f"加载模型出错: {str(e)}") from e

    @staticmethod
    def _default_model_dir():
        """Return the path of the ``pretrained`` directory shipped with FScanpy."""
        try:
            from pkg_resources import resource_filename
        except ImportError:
            # pkg_resources is deprecated and may be absent from recent
            # setuptools installs; fall back to the standard library.
            from importlib.resources import files
            return str(files('FScanpy').joinpath('pretrained'))
        return resource_filename('FScanpy', 'pretrained')

    def _load_pickle(self, path):
        """Deserialize a model file with joblib.

        NOTE: joblib uses pickle under the hood, so loading can execute
        arbitrary code; only point this at trusted model files.
        """
        return joblib.load(path)

    @staticmethod
    def _seq_tag(seq_index):
        """Return the 'sequence N' fragment used in batch-mode log messages."""
        return f"序列 {seq_index} " if seq_index is not None else ""

    def _predict_gb(self, fs_period, seq_index=None):
        """Return the gradient-boosting positive-class probability.

        Best effort: any failure is printed and reported as ``0.0``.
        """
        try:
            features = self.feature_extractor.extract_features(fs_period)

            # The model expects one flat feature vector per sample.
            if isinstance(features, np.ndarray) and features.ndim > 1:
                subject = f"序列 {seq_index} 的" if seq_index is not None else ""
                print(f"警告: {subject}特征是{features.ndim}维数组,进行扁平化处理")
                features = features.flatten()

            # Shape (1, n_features): a single-sample batch.
            feature_row = np.array([features])
            if feature_row.ndim != 2:
                raise ValueError(f"处理后特征仍为{feature_row.ndim}维,需要二维数组")

            return float(self.gb_model.predict_proba(feature_row)[0][1])
        except Exception as e:
            print(f"GB模型预测{self._seq_tag(seq_index)}时出错: {str(e)}")
            return 0.0

    @staticmethod
    def _positive_probability(raw):
        """Extract one probability from a ``predict`` result.

        Accepts a list of outputs (the first is used), a 2-D array with two or
        more columns (column 1 of row 0 is used), or a single-column / 1-D
        output (its first element is used). NumPy scalars expose
        ``__getitem__``, so duck-typing on that attribute cannot tell a scalar
        from a row; the array rank is inspected instead.
        """
        if isinstance(raw, list):
            raw = raw[0]
        scores = np.asarray(raw, dtype=float)
        if scores.ndim > 1 and scores.shape[1] > 1:
            return float(scores[0][1])
        return float(scores.reshape(-1)[0])

    def _predict_cnn(self, full_seq, seq_index=None):
        """Return the second-stage probability for the long window.

        A model exposing ``predict_proba`` is treated as scikit-learn-like and
        fed extracted features; anything else is treated as a deep-learning
        model and fed the processed sequence tensor. Best effort: any failure
        is printed and reported as ``0.0``.
        """
        try:
            if hasattr(self.cnn_model, 'predict_proba'):
                features = self.feature_extractor.extract_features(full_seq)
                if isinstance(features, np.ndarray) and features.ndim > 1:
                    features = features.flatten()
                proba = self.cnn_model.predict_proba(np.array([features]))
                return float(proba[0][1])

            cnn_input = self.cnn_processor.prepare_sequence(full_seq)
            try:
                raw = self.cnn_model.predict(cnn_input)
            except TypeError:
                try:
                    raw = self.cnn_model.predict(cnn_input, verbose=0)
                except Exception:
                    # Last resort: flatten the input to a 2-D single-sample batch.
                    raw = self.cnn_model.predict(cnn_input.reshape(1, -1))
            return self._positive_probability(raw)
        except Exception as e:
            print(f"CNN模型预测{self._seq_tag(seq_index)}时出错: {str(e)}")
            return 0.0

    def _predict_voting(self, gb_prob, cnn_prob, seq_index=None):
        """Combine both scores with the voting model.

        Falls back to the plain average of the two scores if the voting model
        fails.
        """
        try:
            voting_input = np.array([[gb_prob, cnn_prob]])
            return float(self.voting_model.predict_proba(voting_input)[0][1])
        except Exception as e:
            print(f"投票模型预测{self._seq_tag(seq_index)}时出错: {str(e)}")
            return (gb_prob + cnn_prob) / 2

    def _run_cascade(self, fs_period, full_seq, gb_threshold, seq_index=None):
        """Run GB -> CNN -> voting and return the three probabilities as a dict."""
        gb_prob = self._predict_gb(fs_period, seq_index)

        # Cheap screen: skip the expensive second model for low GB scores.
        if gb_prob < gb_threshold:
            return {
                'GB_Probability': gb_prob,
                'CNN_Probability': 0.0,
                'Voting_Probability': 0.0,
            }

        cnn_prob = self._predict_cnn(full_seq, seq_index)
        voting_prob = self._predict_voting(gb_prob, cnn_prob, seq_index)
        return {
            'GB_Probability': gb_prob,
            'CNN_Probability': cnn_prob,
            'Voting_Probability': voting_prob,
        }

    def predict_single_position(self, fs_period, full_seq, gb_threshold=0.1):
        """Predict the PRF probabilities for a single position.

        Args:
            fs_period: Short window; trimmed with the feature extractor when
                longer than ``gb_seq_length``.
            full_seq: Long window passed to the second-stage model.
            gb_threshold: Minimum GB probability required to run the second
                stage and the voting model (default 0.1).

        Returns:
            dict: ``GB_Probability``, ``CNN_Probability`` and
            ``Voting_Probability``. The latter two are ``0.0`` when the GB
            score is below ``gb_threshold``.

        Raises:
            Exception: If the inputs cannot be processed at all.
        """
        try:
            if len(fs_period) > self.gb_seq_length:
                fs_period = self.feature_extractor.trim_sequence(
                    fs_period, self.gb_seq_length)
            return self._run_cascade(fs_period, full_seq, gb_threshold)
        except Exception as e:
            raise Exception(f"预测过程出错: {str(e)}") from e

    @staticmethod
    def _plot_predictions(results_df):
        """Draw the probability curves and the voting heat map; return the figure."""
        # gridspec_kw works on every matplotlib version, whereas the direct
        # height_ratios keyword is only accepted by newer releases.
        fig, (ax1, ax2) = plt.subplots(
            2, 1, figsize=(15, 10), gridspec_kw={'height_ratios': [2, 1]})
        has_data = not results_df.empty

        if has_data:
            ax1.plot(results_df['Position'], results_df['GB_Probability'],
                     label='GB模型', alpha=0.7, linewidth=1.5)
            ax1.plot(results_df['Position'], results_df['CNN_Probability'],
                     label='CNN模型', alpha=0.7, linewidth=1.5)
            ax1.plot(results_df['Position'], results_df['Voting_Probability'],
                     label='投票模型', linewidth=2, color='red')
            ax1.legend()

        ax1.set_xlabel('序列位置')
        ax1.set_ylabel('移码概率')
        ax1.set_title('移码预测概率')
        ax1.grid(True, alpha=0.3)

        if has_data:
            positions = results_df['Position'].values
            # One-row matrix so imshow renders a horizontal strip.
            heatmap_matrix = np.zeros((1, len(positions)))
            heatmap_matrix[0, :] = results_df['Voting_Probability'].values
            im = ax2.imshow(heatmap_matrix, aspect='auto', cmap='YlOrRd',
                            extent=[min(positions), max(positions), 0, 1])
            cbar = fig.colorbar(im, ax=ax2)
            cbar.set_label('移码概率')

        ax2.set_xlabel('序列位置')
        ax2.set_title('移码概率热图')
        ax2.set_yticks([])

        fig.tight_layout()
        return fig

    def predict_full(self, sequence, window_size=3, gb_threshold=0.1, plot=False):
        """Scan a whole sequence for PRF sites with a sliding window.

        Args:
            sequence: Input DNA sequence (converted to an upper-case string).
            window_size: Step between scanned positions (default 3).
            gb_threshold: GB probability threshold (default 0.1).
            plot: Whether to also build a summary figure (default False).

        Returns:
            pd.DataFrame when ``plot`` is False, otherwise a tuple
            ``(pd.DataFrame, matplotlib.figure.Figure)``. Each row holds the
            three probabilities plus ``Position``, ``Codon``, ``33bp`` and
            ``399bp``. With no scorable position the frame is empty and the
            figure has empty axes.

        Raises:
            ValueError: If ``window_size < 1`` or ``gb_threshold < 0``.
            Exception: For any failure during scanning or plotting.
        """
        if window_size < 1:
            raise ValueError("窗口大小必须大于等于1")
        if gb_threshold < 0:
            raise ValueError("GB阈值必须大于等于0")

        results = []

        try:
            sequence = str(sequence).upper()

            for pos in range(0, len(sequence) - 2, window_size):
                fs_period, full_seq = extract_window_sequences(sequence, pos)
                if fs_period is None or full_seq is None:
                    continue

                pred = self.predict_single_position(fs_period, full_seq, gb_threshold)
                pred.update({
                    'Position': pos,
                    'Codon': sequence[pos:pos + 3],
                    '33bp': fs_period,
                    '399bp': full_seq,
                })
                results.append(pred)

            results_df = pd.DataFrame(results)

            if plot:
                return results_df, self._plot_predictions(results_df)
            return results_df
        except Exception as e:
            raise Exception(f"序列预测过程出错: {str(e)}") from e

    def predict_region(self, seq, gb_threshold=0.1):
        """Predict PRF probabilities for one or more long-window sequences.

        Args:
            seq: A single sequence string, an iterable of sequences, a
                ``pd.Series`` of sequences, or a ``pd.DataFrame`` that either
                has a ``'399bp'`` column or exactly one column.
            gb_threshold: GB probability threshold (default 0.1).

        Returns:
            pd.DataFrame: One row per input sequence with ``GB_Probability``,
            ``CNN_Probability``, ``Voting_Probability``, ``33bp`` (upper-cased
            centre window) and ``399bp`` (the input as given).

        Raises:
            Exception: If the input cannot be interpreted or iterated.
        """
        try:
            if isinstance(seq, pd.DataFrame):
                # DataFrame has no tolist(); pick the sequence column explicitly.
                if '399bp' in seq.columns:
                    seq = seq['399bp'].tolist()
                elif seq.shape[1] == 1:
                    seq = seq.iloc[:, 0].tolist()
                else:
                    raise ValueError("DataFrame输入需要包含'399bp'列或仅含一列序列")
            elif isinstance(seq, pd.Series):
                seq = seq.tolist()
            elif isinstance(seq, str):
                seq = [seq]

            results = []
            for i, seq399 in enumerate(seq, start=1):
                # Centre window for the GB model; never fails on short or
                # non-string items because the helper stringifies its input.
                seq33 = self._extract_center_sequence(
                    seq399, target_length=self.gb_seq_length)
                try:
                    row = self._run_cascade(seq33, seq399, gb_threshold, seq_index=i)
                except Exception as e:
                    print(f"处理第 {i} 个序列时出错: {str(e)}")
                    row = {
                        'GB_Probability': 0.0,
                        'CNN_Probability': 0.0,
                        'Voting_Probability': 0.0,
                    }
                row['33bp'] = seq33
                row['399bp'] = seq399
                results.append(row)

            return pd.DataFrame(results)
        except Exception as e:
            raise Exception(f"区域预测过程出错: {str(e)}") from e

    def _extract_center_sequence(self, sequence, target_length=33):
        """Return the upper-cased centre ``target_length`` characters of a sequence.

        Sequences no longer than ``target_length`` are returned whole
        (upper-cased).
        """
        sequence = str(sequence).upper()
        if len(sequence) <= target_length:
            return sequence

        # Because len(sequence) > target_length, the window below always lies
        # inside the sequence, so no boundary clamping is required.
        start = len(sequence) // 2 - target_length // 2
        return sequence[start:start + target_length]