2025-03-18 11:21:54 +08:00
|
|
|
|
import os
|
|
|
|
|
|
import pickle
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
import pandas as pd
|
|
|
|
|
|
from tensorflow.keras.models import load_model
|
|
|
|
|
|
from .features.sequence import SequenceFeatureExtractor
|
|
|
|
|
|
from .features.cnn_input import CNNInputProcessor
|
|
|
|
|
|
from .utils import extract_window_sequences
|
|
|
|
|
|
import matplotlib.pyplot as plt
|
|
|
|
|
|
import joblib
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PRFPredictor:
    """Frameshift (PRF) site predictor combining a short and a long model.

    The "short" model (HistGB per the original comments) scores a 33 bp
    window; the "long" model (BiLSTM-CNN per the original comments) scores a
    399 bp window. The short probability gates the long model, and the two
    probabilities are combined as a weighted average.
    """

    # Integer code per base for models that consume the raw short sequence;
    # any character not listed maps to 0 (same code as 'N').
    _BASE_TO_NUM = {'A': 1, 'T': 2, 'G': 3, 'C': 4, 'N': 0}

    # (current column, legacy alias) pairs added by the deprecated wrappers,
    # in the order the aliases are appended to the result frame.
    _LEGACY_COLUMN_ALIASES = (
        ('Ensemble_Probability', 'Voting_Probability'),
        ('Ensemble_Probability', 'Weighted_Probability'),
        ('Ensemble_Weights', 'Weight_Info'),
        ('Short_Sequence', '33bp'),
        ('Long_Sequence', '399bp'),
    )

    def __init__(self, model_dir=None):
        """Load both pretrained models and build the input processors.

        Args:
            model_dir: Directory containing ``short.pkl`` and ``long.pkl``.
                When omitted, the ``pretrained`` resource directory of the
                installed ``FScanpy`` package is used.

        Raises:
            FileNotFoundError: If either model file cannot be loaded.
            Exception: If any other initialisation step fails (the original
                error is chained as ``__cause__``).
        """
        if model_dir is None:
            # NOTE(review): pkg_resources is deprecated in recent setuptools;
            # importlib.resources is the long-term replacement.
            from pkg_resources import resource_filename
            model_dir = resource_filename('FScanpy', 'pretrained')

        try:
            # Model files follow the "short"/"long" naming convention.
            self.short_model = self._load_pickle(os.path.join(model_dir, 'short.pkl'))  # HistGB model
            self.long_model = self._load_pickle(os.path.join(model_dir, 'long.pkl'))  # BiLSTM-CNN model

            # Window lengths must match the ones used during training.
            self.short_seq_length = 33   # window length for the short (HistGB) model
            self.long_seq_length = 399   # window length for the long (BiLSTM-CNN) model

            self.feature_extractor = SequenceFeatureExtractor(seq_length=self.short_seq_length)
            self.cnn_processor = CNNInputProcessor(max_length=self.long_seq_length)

            # Record once which models expose predict_proba so prediction can
            # dispatch on a flag instead of re-inspecting the model each call.
            self._detect_model_types()
        except FileNotFoundError as e:
            raise FileNotFoundError(
                f"无法找到模型文件: {str(e)}。请确保模型文件 'short.pkl' 和 'long.pkl' 存在于 {model_dir}"
            ) from e
        except Exception as e:
            raise Exception(f"加载模型出错: {str(e)}") from e

    def _load_pickle(self, path):
        """Load a serialised model from ``path`` with joblib.

        Any failure is reported as FileNotFoundError (kept for backward
        compatibility with existing callers), chained to the original error.

        NOTE(review): joblib/pickle loading can execute arbitrary code; only
        point this at trusted model files.
        """
        try:
            return joblib.load(path)
        except Exception as e:
            raise FileNotFoundError(f"无法加载模型文件 {path}: {str(e)}") from e

    def _detect_model_types(self):
        """Set ``short_is_sklearn`` / ``long_is_sklearn`` from ``predict_proba`` availability."""
        self.short_is_sklearn = hasattr(self.short_model, 'predict_proba')
        self.long_is_sklearn = hasattr(self.long_model, 'predict_proba')

    @staticmethod
    def _validate_ensemble_weight(ensemble_weight):
        """Raise ValueError unless 0.0 <= ensemble_weight <= 1.0 (NaN is rejected too)."""
        if not (0.0 <= ensemble_weight <= 1.0):
            raise ValueError("ensemble_weight 必须在 0.0 到 1.0 之间")

    @staticmethod
    def _format_weights(ensemble_weight):
        """Return the human-readable 'Short:x, Long:y' weight label stored in results."""
        return f'Short:{ensemble_weight:.1f}, Long:{1.0 - ensemble_weight:.1f}'

    def _predict_model(self, model, features, is_sklearn, seq_length):
        """Run one model and return its positive-class probability.

        Args:
            model: An estimator with ``predict_proba`` (when ``is_sklearn``)
                or an object with a Keras-style ``predict``.
            features: The feature vector for ``predict_proba`` models, or the
                raw sequence string for ``predict`` models.
            is_sklearn: Use ``predict_proba`` instead of ``predict``.
            seq_length: For ``predict`` models, selects the input encoding:
                ``self.cnn_processor`` when equal to ``self.long_seq_length``,
                otherwise a per-base integer encoding of shape (1, len, 1).

        Returns:
            float: Column 1 of a two-or-more-column output, otherwise the
            first output value.

        Raises:
            Exception: Wrapping any error raised while predicting.
        """
        try:
            if is_sklearn:
                # predict_proba expects a single 2-D row of features.
                if isinstance(features, np.ndarray) and features.ndim > 1:
                    features = features.flatten()
                pred = model.predict_proba(np.array([features]))
                return float(pred[0][1])

            if seq_length == self.long_seq_length:
                model_input = self.cnn_processor.prepare_sequence(features)
            else:
                seq_numeric = [self._BASE_TO_NUM.get(base, 0) for base in features.upper()]
                model_input = np.array(seq_numeric).reshape(1, len(seq_numeric), 1)

            try:
                pred = model.predict(model_input, verbose=0)
            except TypeError:
                # predict() implementations that do not accept `verbose`.
                pred = model.predict(model_input)

            # Multi-output models return a list; use the first output.
            if isinstance(pred, list):
                pred = pred[0]

            # Normalise so 1-D, 2-D and nested outputs are all handled;
            # indexing a numpy scalar (1-D output) would otherwise fail.
            arr = np.asarray(pred)
            if arr.ndim > 1 and arr.shape[1] > 1:
                return float(arr[0][1])
            return float(arr.reshape(-1)[0])
        except Exception as e:
            raise Exception(f"模型预测失败: {str(e)}") from e

    def _score_sequence(self, model, is_sklearn, sequence, seq_length, label):
        """Score ``sequence`` with ``model``; on failure print a message and return 0.0.

        Failures are deliberately best-effort so a single bad window does not
        abort a whole scan. ``label`` ('Short' or 'Long') prefixes the message.
        """
        try:
            if is_sklearn:
                features = self.feature_extractor.extract_features(sequence)
                return self._predict_model(model, features, True, seq_length)
            return self._predict_model(model, sequence, False, seq_length)
        except Exception as e:
            print(f"{label}模型预测时出错: {str(e)}")
            return 0.0

    def predict_single_position(self, fs_period, full_seq, short_threshold=0.1, ensemble_weight=0.4):
        '''Predict the PRF probabilities for one position.

        Args:
            fs_period: Short window (33 bp) scored by the short model; longer
                input is trimmed by the feature extractor.
            full_seq: Long window scored by the long model.
            short_threshold: The long model is only run when the short
                probability is at least this value (default 0.1).
            ensemble_weight: Weight of the short model in the ensemble
                (default 0.4; the long model gets 1 - ensemble_weight).

        Returns:
            dict: ``Short_Probability``, ``Long_Probability``,
            ``Ensemble_Probability`` and the ``Ensemble_Weights`` label. When
            the short probability is below ``short_threshold`` the long and
            ensemble probabilities are 0.0.

        Raises:
            ValueError: If ``ensemble_weight`` is outside [0, 1].
            Exception: Wrapping any other unexpected error.
        '''
        # Validate before the blanket handler so callers get a real ValueError.
        self._validate_ensemble_weight(ensemble_weight)
        long_weight = 1.0 - ensemble_weight
        weights_label = self._format_weights(ensemble_weight)

        try:
            if len(fs_period) > self.short_seq_length:
                fs_period = self.feature_extractor.trim_sequence(fs_period, self.short_seq_length)

            short_prob = self._score_sequence(
                self.short_model, self.short_is_sklearn, fs_period, self.short_seq_length, 'Short')

            # A low short-model probability short-circuits the (costlier) long model.
            if short_prob < short_threshold:
                return {
                    'Short_Probability': short_prob,
                    'Long_Probability': 0.0,
                    'Ensemble_Probability': 0.0,
                    'Ensemble_Weights': weights_label,
                }

            long_prob = self._score_sequence(
                self.long_model, self.long_is_sklearn, full_seq, self.long_seq_length, 'Long')

            return {
                'Short_Probability': short_prob,
                'Long_Probability': long_prob,
                'Ensemble_Probability': ensemble_weight * short_prob + long_weight * long_prob,
                'Ensemble_Weights': weights_label,
            }
        except Exception as e:
            raise Exception(f"预测过程出错: {str(e)}") from e

    def predict_sequence(self, sequence, window_size=3, short_threshold=0.1, ensemble_weight=0.4):
        """Predict PRF sites along a full sequence with a sliding window.

        Args:
            sequence: Input DNA sequence (converted with ``str`` and upper-cased).
            window_size: Step between scored positions (default 3).
            short_threshold: Short-model gate for running the long model
                (default 0.1).
            ensemble_weight: Weight of the short model in the ensemble
                (default 0.4).

        Returns:
            pd.DataFrame: One row per scored position with the probability
            columns plus ``Position``, ``Codon``, ``Short_Sequence`` and
            ``Long_Sequence``. Positions for which ``extract_window_sequences``
            returns None are skipped; the frame is empty if none remain.

        Raises:
            ValueError: For an invalid ``window_size``, ``short_threshold`` or
                ``ensemble_weight``.
            Exception: Wrapping any error raised during the scan.
        """
        if window_size < 1:
            raise ValueError("窗口大小必须大于等于1")
        if short_threshold < 0:
            raise ValueError("short模型阈值必须大于等于0")
        self._validate_ensemble_weight(ensemble_weight)

        try:
            sequence = str(sequence).upper()
            results = []

            # Stop two bases early so every position has a full codon.
            for pos in range(0, len(sequence) - 2, window_size):
                fs_period, full_seq = extract_window_sequences(sequence, pos)
                if fs_period is None or full_seq is None:
                    continue

                pred = self.predict_single_position(fs_period, full_seq, short_threshold, ensemble_weight)
                pred.update({
                    'Position': pos,
                    'Codon': sequence[pos:pos + 3],
                    'Short_Sequence': fs_period,
                    'Long_Sequence': full_seq,
                })
                results.append(pred)

            return pd.DataFrame(results)
        except Exception as e:
            raise Exception(f"序列预测过程出错: {str(e)}") from e

    def plot_sequence_prediction(self, sequence, window_size=3, short_threshold=0.65,
                                 long_threshold=0.8, ensemble_weight=0.4, title=None, save_path=None,
                                 figsize=(12, 6), dpi=300):
        """Plot the frameshift-probability profile of a sequence.

        Args:
            sequence: Input DNA sequence.
            window_size: Sliding-window step, also used as the bar width (default 3).
            short_threshold: Display filter on the short (HistGB) probability (default 0.65).
            long_threshold: Display filter on the long (BiLSTM-CNN) probability (default 0.8).
            ensemble_weight: Weight of the short model in the ensemble (default 0.4).
            title: Figure title (optional; a default title is generated).
            save_path: If given, the figure is saved there; for a ``.png``
                path a ``.pdf`` copy with the same stem is saved as well.
            figsize: Figure size (default (12, 6)).
            dpi: Resolution used for ``save_path`` (default 300).

        Returns:
            tuple: ``(pd.DataFrame, matplotlib.figure.Figure)`` with the
            unfiltered predictions and the figure.

        Raises:
            ValueError: For an invalid ``ensemble_weight`` or empty predictions.
            Exception: Wrapping any error raised while plotting or saving.
        """
        self._validate_ensemble_weight(ensemble_weight)
        long_weight = 1.0 - ensemble_weight

        # predict_sequence() only runs the long model where the short
        # probability reaches its gate. Keep the usual permissive gate of 0.1,
        # but never stricter than the display threshold, so every window that
        # can pass the filter below has a real long-model score. Clamp at 0
        # because predict_sequence() rejects negative thresholds.
        gate = max(0.0, min(0.1, short_threshold))
        results_df = self.predict_sequence(sequence, window_size=window_size,
                                           short_threshold=gate, ensemble_weight=ensemble_weight)
        if results_df.empty:
            raise ValueError("预测结果为空,请检查输入序列")

        fig = None
        try:
            seq_length = len(sequence)
            # Widen heat-map marks on long sequences so they remain visible.
            prob_width = max(1, seq_length // 300)
            half_width = prob_width // 2

            # Windows that pass both display thresholds.
            passes = ((results_df['Short_Probability'] >= short_threshold) &
                      (results_df['Long_Probability'] >= long_threshold))

            # Two stacked panels: a thin heat strip over a bar chart.
            fig = plt.figure(figsize=figsize)
            gs = fig.add_gridspec(2, 1, height_ratios=[0.15, 1], hspace=0.3)

            if title:
                fig.suptitle(title, y=0.95, fontsize=12)
            else:
                fig.suptitle(f'序列移码概率预测结果 (权重 {ensemble_weight:.1f}:{long_weight:.1f})', y=0.95, fontsize=12)

            # Heat strip: ensemble probability painted around each passing position.
            ax0 = fig.add_subplot(gs[0])
            prob_data = np.zeros((1, seq_length))
            passing = results_df.loc[passes, ['Position', 'Ensemble_Probability']]
            for pos, prob in zip(passing['Position'], passing['Ensemble_Probability']):
                pos = int(pos)
                start = max(0, pos - half_width)
                end = min(seq_length, pos + half_width + 1)
                prob_data[0, start:end] = prob

            ax0.imshow(prob_data, cmap='Reds', aspect='auto', vmin=0, vmax=1,
                       interpolation='nearest')
            ax0.set_xticks([])
            ax0.set_yticks([])
            ax0.set_title(f'预测概率热图 (Short≥{short_threshold}, Long≥{long_threshold})',
                          pad=5, fontsize=10)

            # Bar chart: ensemble probability, zeroed where a threshold fails.
            ax1 = fig.add_subplot(gs[1])
            filtered_probs = results_df['Ensemble_Probability'].where(passes, 0.0)
            ax1.bar(results_df['Position'], filtered_probs,
                    alpha=0.7, color='darkred', width=max(1, window_size))

            step = max(seq_length // 10, 50)
            ax1.set_xticks(np.arange(0, seq_length, step))
            ax1.tick_params(axis='x', rotation=45)

            ax1.set_xlabel('序列位置 (bp)', fontsize=10)
            ax1.set_ylabel('移码概率', fontsize=10)
            ax1.set_title(f'移码概率分布 (集成权重 {ensemble_weight:.1f}:{long_weight:.1f})', fontsize=11)
            ax1.set_ylim(0, 1)
            ax1.grid(True, alpha=0.3)

            info_text = (f'过滤阈值: Short≥{short_threshold}, Long≥{long_threshold}\n'
                         f'集成权重: Short:{ensemble_weight:.1f}, Long:{long_weight:.1f}')
            ax1.text(0.02, 0.95, info_text, transform=ax1.transAxes,
                     fontsize=9, verticalalignment='top',
                     bbox=dict(boxstyle='round', facecolor='lightgray', alpha=0.8))

            # Keep both panels on the same x range so positions line up.
            for ax in (ax0, ax1):
                ax.set_xlim(-1, seq_length)

            fig.tight_layout()

            if save_path:
                fig.savefig(save_path, dpi=dpi, bbox_inches='tight')
                # Also save a PDF next to a PNG; splitext only touches the extension.
                root, ext = os.path.splitext(save_path)
                if ext.lower() == '.png':
                    fig.savefig(root + '.pdf', bbox_inches='tight')
                print(f"图片已保存至: {save_path}")

            return results_df, fig
        except Exception as e:
            # Do not leave a half-built figure registered with pyplot.
            if fig is not None:
                plt.close(fig)
            raise Exception(f"绘制序列预测图时出错: {str(e)}") from e

    @staticmethod
    def _normalize_region_input(sequences):
        """Convert the accepted ``predict_regions`` inputs into an iterable of sequences.

        A DataFrame must have exactly one column, or a ``Long_Sequence`` /
        ``399bp`` column. A Series is converted to a list, and a single
        string is wrapped in a list. Any other iterable is returned unchanged.
        """
        if isinstance(sequences, pd.DataFrame):
            if sequences.shape[1] == 1:
                return sequences.iloc[:, 0].tolist()
            for column in ('Long_Sequence', '399bp'):
                if column in sequences.columns:
                    return sequences[column].tolist()
            raise ValueError("DataFrame 输入必须只有一列,或包含 'Long_Sequence' / '399bp' 列")
        if isinstance(sequences, pd.Series):
            return sequences.tolist()
        if isinstance(sequences, str):
            return [sequences]
        return sequences

    def predict_regions(self, sequences, short_threshold=0.1, ensemble_weight=0.4):
        '''Batch-predict pre-extracted long (399 bp) region sequences.

        The short-model input is the centred 33 bp slice of each sequence.

        Args:
            sequences: A single sequence, or a list/Series/DataFrame of them.
                A DataFrame must have one column or a ``Long_Sequence`` /
                ``399bp`` column.
            short_threshold: Short-model gate for running the long model (default 0.1).
            ensemble_weight: Weight of the short model in the ensemble (default 0.4).

        Returns:
            DataFrame: One row per input sequence. A sequence whose prediction
            fails gets all-zero probabilities, and a message is printed.

        Raises:
            ValueError: If ``ensemble_weight`` is outside [0, 1].
            Exception: Wrapping any other unexpected error.
        '''
        self._validate_ensemble_weight(ensemble_weight)
        weights_label = self._format_weights(ensemble_weight)

        try:
            results = []
            for i, seq399 in enumerate(self._normalize_region_input(sequences)):
                # Computed outside the per-sequence try so the fallback row
                # never has to re-inspect a possibly malformed input.
                seq33 = self._extract_center_sequence(seq399, target_length=self.short_seq_length)
                try:
                    pred_result = self.predict_single_position(seq33, seq399, short_threshold, ensemble_weight)
                except Exception as e:
                    print(f"处理第 {i+1} 个序列时出错: {str(e)}")
                    pred_result = {
                        'Short_Probability': 0.0,
                        'Long_Probability': 0.0,
                        'Ensemble_Probability': 0.0,
                        'Ensemble_Weights': weights_label,
                    }
                pred_result.update({
                    'Short_Sequence': seq33,
                    'Long_Sequence': seq399,
                })
                results.append(pred_result)

            return pd.DataFrame(results)
        except Exception as e:
            raise Exception(f"区域预测过程出错: {str(e)}") from e

    def _extract_center_sequence(self, sequence, target_length=33):
        """Return the centred ``target_length`` slice of ``sequence``, upper-cased.

        The input is converted with ``str``. Sequences no longer than
        ``target_length`` are returned whole. The slice starts at
        ``len // 2 - target_length // 2``.
        """
        sequence = str(sequence).upper()

        if len(sequence) <= target_length:
            return sequence

        start = len(sequence) // 2 - target_length // 2
        end = start + target_length

        # Defensive clamps; for len > target_length the slice already fits.
        if start < 0:
            start, end = 0, target_length
        elif end > len(sequence):
            end = len(sequence)
            start = end - target_length

        return sequence[start:end]

    @classmethod
    def _add_legacy_columns(cls, results_df):
        """Add pre-rename column aliases in place, where the source column exists, and return the frame."""
        for source, alias in cls._LEGACY_COLUMN_ALIASES:
            if source in results_df.columns:
                results_df[alias] = results_df[source]
        return results_df

    # Backward-compatible wrappers (deprecated).
    def predict_full(self, sequence, window_size=3, short_threshold=0.1, short_weight=0.4, plot=False):
        """Deprecated: use :meth:`predict_sequence`.

        Calls ``predict_sequence`` and adds the legacy column aliases. When
        ``plot`` is true, also builds the plot with thresholds 0.65/0.8 and
        returns ``(results_df, fig)``.
        """
        import warnings
        warnings.warn("predict_full() 已废弃,请使用 predict_sequence() 方法", DeprecationWarning, stacklevel=2)

        results_df = self._add_legacy_columns(
            self.predict_sequence(sequence, window_size, short_threshold, short_weight))

        if plot:
            _, fig = self.plot_sequence_prediction(sequence, window_size, 0.65, 0.8, short_weight)
            return results_df, fig

        return results_df

    def predict_region(self, seq, short_threshold=0.1, short_weight=0.4):
        """Deprecated: use :meth:`predict_regions`.

        Calls ``predict_regions`` and adds the legacy column aliases.
        """
        import warnings
        warnings.warn("predict_region() 已废弃,请使用 predict_regions() 方法", DeprecationWarning, stacklevel=2)

        return self._add_legacy_columns(self.predict_regions(seq, short_threshold, short_weight))