# FScanpy-package/FScanpy/predictor.py
# NOTE: user-facing messages in this module intentionally contain full-width
# (CJK) punctuation such as "," which some code viewers flag as
# "ambiguous Unicode characters".

import os
import pickle
import numpy as np
import pandas as pd
from tensorflow.keras.models import load_model
from .features.sequence import SequenceFeatureExtractor
from .features.cnn_input import CNNInputProcessor
from .utils import extract_window_sequences
import matplotlib.pyplot as plt
import joblib
class PRFPredictor:
    """Score frameshift (PRF) sites with a short-window and a long-window model.

    The short model scores a 33 bp window (described in this file as a HistGB
    model) and the long model a 399 bp window (described as a BiLSTM-CNN).
    Their probabilities are combined as
    ``ensemble_weight * short + (1 - ensemble_weight) * long``; the long model
    is only run when the short probability reaches ``short_threshold``.
    """

    def __init__(self, model_dir=None):
        """Load the pretrained models and the matching feature processors.

        Args:
            model_dir: Directory containing ``short.pkl`` and ``long.pkl``.
                Defaults to the ``pretrained`` directory of the installed
                ``FScanpy`` package.

        Raises:
            FileNotFoundError: If a model file cannot be loaded.
            Exception: For any other initialisation failure.
        """
        if model_dir is None:
            model_dir = self._default_model_dir()
        try:
            # NOTE: joblib.load unpickles arbitrary objects -- only point
            # model_dir at trusted model files.
            self.short_model = self._load_pickle(os.path.join(model_dir, 'short.pkl'))  # HistGB model
            self.long_model = self._load_pickle(os.path.join(model_dir, 'long.pkl'))  # BiLSTM-CNN model
            # Window lengths; the comments in this file state they match training.
            self.short_seq_length = 33
            self.long_seq_length = 399
            # Feature extractor for the short window, CNN input builder for the long one.
            self.feature_extractor = SequenceFeatureExtractor(seq_length=self.short_seq_length)
            self.cnn_processor = CNNInputProcessor(max_length=self.long_seq_length)
            self._detect_model_types()
        except FileNotFoundError as e:
            raise FileNotFoundError(
                f"无法找到模型文件: {e}。请确保模型文件 'short.pkl''long.pkl' 存在于 {model_dir}"
            ) from e
        except Exception as e:
            raise Exception(f"加载模型出错: {e}") from e

    @staticmethod
    def _default_model_dir():
        """Return the path of the ``pretrained`` directory bundled with FScanpy."""
        try:
            from importlib.resources import files
        except ImportError:
            # Python < 3.9: fall back to the deprecated pkg_resources API.
            from pkg_resources import resource_filename
            return resource_filename('FScanpy', 'pretrained')
        return str(files('FScanpy') / 'pretrained')

    def _load_pickle(self, path):
        """Load a joblib/pickle model file.

        Every load failure is reported as FileNotFoundError (``__init__`` relies
        on that type); the underlying error is chained as ``__cause__``.
        """
        try:
            return joblib.load(path)
        except Exception as e:
            raise FileNotFoundError(f"无法加载模型文件 {path}: {e}") from e

    def _detect_model_types(self):
        """Record which models expose ``predict_proba`` (sklearn-style API)."""
        self.short_is_sklearn = hasattr(self.short_model, 'predict_proba')
        self.long_is_sklearn = hasattr(self.long_model, 'predict_proba')

    @staticmethod
    def _to_float(value):
        """Convert a scalar-like model output (Python/NumPy scalar or size-1 array) to float."""
        return float(np.asarray(value).item())

    @staticmethod
    def _long_weight(ensemble_weight):
        """Validate ``ensemble_weight`` and return the complementary long-model weight."""
        if not (0.0 <= ensemble_weight <= 1.0):
            raise ValueError("ensemble_weight 必须在 0.0 到 1.0 之间")
        return 1.0 - ensemble_weight

    def _predict_model(self, model, features, is_sklearn, seq_length):
        """Run one model on one input and return its positive-class probability.

        Args:
            model: Loaded model; sklearn-style models are called via
                ``predict_proba``, all others via ``predict``.
            features: Feature vector for sklearn models, raw sequence string
                otherwise.
            is_sklearn: Whether ``model`` exposes ``predict_proba``.
            seq_length: Window length; equal to ``self.long_seq_length`` selects
                the CNN input path, anything else the integer-encoding path.

        Returns:
            float: The predicted probability.

        Raises:
            Exception: Wrapping any failure in input preparation or prediction.
        """
        try:
            if is_sklearn:
                # sklearn expects a single 2-D row of features.
                if isinstance(features, np.ndarray) and features.ndim > 1:
                    features = features.flatten()
                pred = model.predict_proba(np.array([features]))
                return self._to_float(pred[0][1])

            if seq_length == self.long_seq_length:
                model_input = self.cnn_processor.prepare_sequence(features)
            else:
                # Integer-encode the short window: A=1, T=2, G=3, C=4, other=0,
                # shaped (1, length, 1).
                base_to_num = {'A': 1, 'T': 2, 'G': 3, 'C': 4, 'N': 0}
                seq_numeric = [base_to_num.get(base, 0) for base in features.upper()]
                model_input = np.array(seq_numeric).reshape(1, len(seq_numeric), 1)

            # Keras-style predict() accepts verbose=; other implementations may not.
            try:
                pred = model.predict(model_input, verbose=0)
            except TypeError:
                pred = model.predict(model_input)

            if isinstance(pred, list):
                pred = pred[0]
            if hasattr(pred, 'shape') and len(pred.shape) > 1 and pred.shape[1] > 1:
                # Multi-column output: column 1 is taken as the positive class.
                return self._to_float(pred[0][1])
            first = pred[0]
            return self._to_float(first[0] if hasattr(first, '__getitem__') else first)
        except Exception as e:
            raise Exception(f"模型预测失败: {e}") from e

    def _predict_or_zero(self, label, model, is_sklearn, sequence, seq_length):
        """Best-effort prediction: print the error and score 0.0 if the model fails."""
        try:
            if is_sklearn:
                features = self.feature_extractor.extract_features(sequence)
                return self._predict_model(model, features, True, seq_length)
            return self._predict_model(model, sequence, False, seq_length)
        except Exception as e:
            print(f"{label}模型预测时出错: {e}")
            return 0.0

    def predict_single_position(self, fs_period, full_seq, short_threshold=0.1, ensemble_weight=0.4):
        """Score one candidate site with both models.

        Args:
            fs_period: Short-model window (33 bp); longer input is trimmed with
                ``feature_extractor.trim_sequence``.
            full_seq: Long-model window (399 bp expected).
            short_threshold: When the short probability is below this value the
                long model is skipped and the ensemble probability is 0.0.
            ensemble_weight: Weight of the short model in [0, 1]; the long
                model gets ``1 - ensemble_weight``.

        Returns:
            dict: ``Short_Probability``, ``Long_Probability`` and
            ``Ensemble_Probability`` (floats) plus ``Ensemble_Weights`` (str).

        Raises:
            ValueError: If ``ensemble_weight`` is outside [0, 1].
            Exception: Wrapping any other unexpected failure. A failing model
                does not raise; its error is printed and it scores 0.0.
        """
        long_weight = self._long_weight(ensemble_weight)
        weight_info = f'Short:{ensemble_weight:.1f}, Long:{long_weight:.1f}'
        try:
            if len(fs_period) > self.short_seq_length:
                fs_period = self.feature_extractor.trim_sequence(fs_period, self.short_seq_length)

            short_prob = self._predict_or_zero('Short', self.short_model, self.short_is_sklearn,
                                               fs_period, self.short_seq_length)
            # Cheap short model acts as a gate for the expensive long model.
            if short_prob < short_threshold:
                return {
                    'Short_Probability': short_prob,
                    'Long_Probability': 0.0,
                    'Ensemble_Probability': 0.0,
                    'Ensemble_Weights': weight_info,
                }

            long_prob = self._predict_or_zero('Long', self.long_model, self.long_is_sklearn,
                                              full_seq, self.long_seq_length)
            return {
                'Short_Probability': short_prob,
                'Long_Probability': long_prob,
                'Ensemble_Probability': ensemble_weight * short_prob + long_weight * long_prob,
                'Ensemble_Weights': weight_info,
            }
        except Exception as e:
            raise Exception(f"预测过程出错: {e}") from e

    def predict_sequence(self, sequence, window_size=3, short_threshold=0.1, ensemble_weight=0.4):
        """Scan a full sequence and score every ``window_size``-th position.

        Args:
            sequence: Input DNA sequence (converted with ``str`` and upper-cased).
            window_size: Step in bp between scanned positions (>= 1).
            short_threshold: Short-model gate, see ``predict_single_position``.
            ensemble_weight: Weight of the short model in [0, 1].

        Returns:
            pd.DataFrame: One row per scanned position with the keys of
            ``predict_single_position`` plus ``Position``, ``Codon``,
            ``Short_Sequence`` and ``Long_Sequence``. Positions for which
            ``extract_window_sequences`` returns a None window are skipped; if
            all are skipped the DataFrame is empty.

        Raises:
            ValueError: On invalid ``window_size``, ``short_threshold`` or
                ``ensemble_weight``.
            Exception: Wrapping any failure during scanning.
        """
        if window_size < 1:
            raise ValueError("窗口大小必须大于等于1")
        if short_threshold < 0:
            raise ValueError("short模型阈值必须大于等于0")
        self._long_weight(ensemble_weight)

        results = []
        try:
            sequence = str(sequence).upper()
            for pos in range(0, len(sequence) - 2, window_size):
                fs_period, full_seq = extract_window_sequences(sequence, pos)
                if fs_period is None or full_seq is None:
                    continue
                pred = self.predict_single_position(fs_period, full_seq, short_threshold, ensemble_weight)
                pred.update({
                    'Position': pos,
                    'Codon': sequence[pos:pos + 3],
                    'Short_Sequence': fs_period,
                    'Long_Sequence': full_seq,
                })
                results.append(pred)
            return pd.DataFrame(results)
        except Exception as e:
            raise Exception(f"序列预测过程出错: {e}") from e

    @staticmethod
    def _companion_pdf_path(save_path):
        """Return the ``.pdf`` sibling of a ``.png`` path (any case), else None."""
        root, ext = os.path.splitext(os.fspath(save_path))
        if ext.lower() != '.png':
            return None
        return root + '.pdf'

    def plot_sequence_prediction(self, sequence, window_size=3, short_threshold=0.65,
                                 long_threshold=0.8, ensemble_weight=0.4, title=None, save_path=None,
                                 figsize=(12, 6), dpi=300):
        """Plot the frameshift probability along a sequence.

        A position is shown only if Short >= ``short_threshold`` and
        Long >= ``long_threshold``. The top panel is a one-row heat map of the
        ensemble probability and the bottom panel a bar chart.

        Args:
            sequence: Input DNA sequence.
            window_size: Step in bp between scanned positions.
            short_threshold: Display filter for the short (HistGB) probability.
            long_threshold: Display filter for the long (BiLSTM-CNN) probability.
            ensemble_weight: Weight of the short model in [0, 1].
            title: Optional figure title.
            save_path: Optional output path; a ``.png`` path also gets a ``.pdf``
                copy next to it.
            figsize: Figure size in inches.
            dpi: Resolution used for ``save_path``.

        Returns:
            tuple: ``(results_df, fig)``, the predictions and the matplotlib figure.

        Raises:
            ValueError: On invalid arguments or when no position could be scored.
            Exception: Wrapping any other failure while predicting or plotting.
        """
        long_weight = self._long_weight(ensemble_weight)
        try:
            # predict_sequence runs the long model only where Short >= gate. Keep
            # the historical gate of 0.1, but lower it when the display threshold
            # is lower, so every position that can pass the Short filter below
            # also has a real Long probability.
            gate = max(0.0, min(0.1, short_threshold))
            results_df = self.predict_sequence(sequence, window_size=window_size,
                                               short_threshold=gate, ensemble_weight=ensemble_weight)
            if results_df.empty:
                raise ValueError("预测结果为空,请检查输入序列")

            seq_length = len(sequence)
            prob_width = max(1, seq_length // 300)  # width in bp of each heat-map mark

            # Single mask shared by both panels.
            keep = ((results_df['Short_Probability'] >= short_threshold) &
                    (results_df['Long_Probability'] >= long_threshold))

            fig = plt.figure(figsize=figsize)
            gs = fig.add_gridspec(2, 1, height_ratios=[0.15, 1], hspace=0.3)
            if title:
                fig.suptitle(title, y=0.95, fontsize=12)
            else:
                fig.suptitle(f'序列移码概率预测结果 (权重 {ensemble_weight:.1f}:{long_weight:.1f})', y=0.95, fontsize=12)

            # Top panel: heat map. Later rows overwrite overlapping marks.
            ax0 = fig.add_subplot(gs[0])
            prob_data = np.zeros((1, seq_length))
            half = prob_width // 2
            kept = results_df.loc[keep, ['Position', 'Ensemble_Probability']]
            for pos, prob in kept.itertuples(index=False, name=None):
                pos = int(pos)
                prob_data[0, max(0, pos - half):min(seq_length, pos + half + 1)] = prob
            ax0.imshow(prob_data, cmap='Reds', aspect='auto', vmin=0, vmax=1,
                       interpolation='nearest')
            ax0.set_xticks([])
            ax0.set_yticks([])
            ax0.set_title(f'预测概率热图 (Short≥{short_threshold}, Long≥{long_threshold})',
                          pad=5, fontsize=10)

            # Bottom panel: bar chart, filtered positions drawn at height 0.
            ax1 = fig.add_subplot(gs[1])
            filtered_probs = results_df['Ensemble_Probability'].where(keep, 0)
            ax1.bar(results_df['Position'], filtered_probs,
                    alpha=0.7, color='darkred', width=max(1, window_size))

            step = max(seq_length // 10, 50)
            ax1.set_xticks(np.arange(0, seq_length, step))
            ax1.tick_params(axis='x', rotation=45)
            ax1.set_xlabel('序列位置 (bp)', fontsize=10)
            ax1.set_ylabel('移码概率', fontsize=10)
            ax1.set_title(f'移码概率分布 (集成权重 {ensemble_weight:.1f}:{long_weight:.1f})', fontsize=11)
            ax1.set_ylim(0, 1)
            ax1.grid(True, alpha=0.3)

            info_text = (f'过滤阈值: Short≥{short_threshold}, Long≥{long_threshold}\n'
                         f'集成权重: Short:{ensemble_weight:.1f}, Long:{long_weight:.1f}')
            ax1.text(0.02, 0.95, info_text, transform=ax1.transAxes,
                     fontsize=9, verticalalignment='top',
                     bbox=dict(boxstyle='round', facecolor='lightgray', alpha=0.8))

            # Keep both panels aligned on the same x range.
            for ax in (ax0, ax1):
                ax.set_xlim(-1, seq_length)

            # Operate on this figure explicitly rather than pyplot's "current" one.
            fig.tight_layout()
            if save_path:
                fig.savefig(save_path, dpi=dpi, bbox_inches='tight')
                pdf_path = self._companion_pdf_path(save_path)
                if pdf_path is not None:
                    fig.savefig(pdf_path, bbox_inches='tight')
                print(f"图片已保存至: {save_path}")
            return results_df, fig
        except ValueError:
            raise
        except Exception as e:
            raise Exception(f"绘制序列预测图时出错: {e}") from e

    @staticmethod
    def _as_sequence_list(sequences):
        """Normalise the inputs accepted by ``predict_regions`` to an iterable of sequences.

        A DataFrame must either contain a ``Long_Sequence`` or ``399bp`` column
        (the column names this class emits) or have exactly one column.
        """
        if isinstance(sequences, str):
            return [sequences]
        if isinstance(sequences, pd.Series):
            return sequences.tolist()
        if isinstance(sequences, pd.DataFrame):
            for column in ('Long_Sequence', '399bp'):
                if column in sequences.columns:
                    return sequences[column].tolist()
            if sequences.shape[1] == 1:
                return sequences.iloc[:, 0].tolist()
            raise ValueError("DataFrame 输入必须只有一列,或包含 'Long_Sequence' / '399bp' 列")
        return sequences

    def predict_regions(self, sequences, short_threshold=0.1, ensemble_weight=0.4):
        """Score pre-extracted 399 bp regions in batch.

        Args:
            sequences: One sequence (str), an iterable of sequences, a Series,
                or a DataFrame (``Long_Sequence``/``399bp`` column or a single
                column). Each item is converted with ``str`` and upper-cased.
            short_threshold: Short-model gate, see ``predict_single_position``.
            ensemble_weight: Weight of the short model in [0, 1].

        Returns:
            pd.DataFrame: One row per input with the keys of
            ``predict_single_position`` plus ``Short_Sequence`` (centre 33 bp)
            and ``Long_Sequence``. A sequence that fails is printed and gets an
            all-zero row.

        Raises:
            ValueError: On an invalid ``ensemble_weight`` or DataFrame layout.
            Exception: Wrapping any other failure.
        """
        long_weight = self._long_weight(ensemble_weight)
        sequences = self._as_sequence_list(sequences)
        try:
            results = []
            for i, seq399 in enumerate(sequences):
                # Normalise like predict_sequence does so both models see upper case.
                seq399 = str(seq399).upper()
                seq33 = self._extract_center_sequence(seq399, target_length=self.short_seq_length)
                try:
                    pred_result = self.predict_single_position(seq33, seq399, short_threshold, ensemble_weight)
                except Exception as e:
                    print(f"处理第 {i+1} 个序列时出错: {e}")
                    pred_result = {
                        'Short_Probability': 0.0,
                        'Long_Probability': 0.0,
                        'Ensemble_Probability': 0.0,
                        'Ensemble_Weights': f'Short:{ensemble_weight:.1f}, Long:{long_weight:.1f}',
                    }
                pred_result.update({
                    'Short_Sequence': seq33,
                    'Long_Sequence': seq399,
                })
                results.append(pred_result)
            return pd.DataFrame(results)
        except Exception as e:
            raise Exception(f"区域预测过程出错: {e}") from e

    def _extract_center_sequence(self, sequence, target_length=33):
        """Return the centred ``target_length`` slice of ``sequence``, upper-cased.

        Sequences not longer than ``target_length`` are returned whole.
        """
        sequence = str(sequence).upper()
        if len(sequence) <= target_length:
            return sequence
        # For len(sequence) > target_length this window always lies inside the
        # sequence, so no clamping is needed.
        start = len(sequence) // 2 - target_length // 2
        return sequence[start:start + target_length]

    @staticmethod
    def _add_legacy_columns(results_df):
        """Add the pre-rename column aliases in place and return ``results_df``."""
        aliases = (
            ('Ensemble_Probability', ('Voting_Probability', 'Weighted_Probability')),
            ('Ensemble_Weights', ('Weight_Info',)),
            ('Short_Sequence', ('33bp',)),
            ('Long_Sequence', ('399bp',)),
        )
        for source, targets in aliases:
            if source in results_df.columns:
                for target in targets:
                    results_df[target] = results_df[source]
        return results_df

    # Backward-compatible aliases (deprecated).
    def predict_full(self, sequence, window_size=3, short_threshold=0.1, short_weight=0.4, plot=False):
        """Deprecated: use ``predict_sequence()``.

        Calls ``predict_sequence`` and adds the legacy column aliases. With
        ``plot=True`` it also draws the plot (thresholds 0.65 / 0.8) and returns
        ``(results_df, fig)``.
        """
        import warnings
        warnings.warn("predict_full() 已废弃,请使用 predict_sequence() 方法", DeprecationWarning, stacklevel=2)
        results_df = self._add_legacy_columns(
            self.predict_sequence(sequence, window_size, short_threshold, short_weight))
        if plot:
            _, fig = self.plot_sequence_prediction(sequence, window_size, 0.65, 0.8, short_weight)
            return results_df, fig
        return results_df

    def predict_region(self, seq, short_threshold=0.1, short_weight=0.4):
        """Deprecated: use ``predict_regions()``; adds the legacy column aliases."""
        import warnings
        warnings.warn("predict_region() 已废弃,请使用 predict_regions() 方法", DeprecationWarning, stacklevel=2)
        return self._add_legacy_columns(self.predict_regions(seq, short_threshold, short_weight))