FScanpy-package/FScanpy/predictor.py

499 lines
22 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import pickle
import numpy as np
import pandas as pd
from tensorflow.keras.models import load_model
from .features.sequence import SequenceFeatureExtractor
from .features.cnn_input import CNNInputProcessor
from .utils import extract_window_sequences
import matplotlib.pyplot as plt
import joblib
class PRFPredictor:
    """Ensemble predictor for PRF sites built from a short and a long model.

    The short model (HistGB) scores a 33 nt window and the long model
    (BiLSTM-CNN) scores a 399 nt window. The two probabilities are combined by
    a weighted average, and the long model is only consulted when the short
    model's probability reaches a threshold.
    """

    # Numeric encoding used when the short model is a deep-learning model;
    # any base not listed maps to 0.
    _BASE_TO_NUM = {'A': 1, 'T': 2, 'G': 3, 'C': 4, 'N': 0}

    def __init__(self, model_dir=None):
        """Load the pretrained models and build the feature processors.

        Args:
            model_dir: Directory containing 'short.pkl' and 'long.pkl'.
                Defaults to the 'pretrained' directory of the installed
                FScanpy package.

        Raises:
            FileNotFoundError: If a model file is missing or cannot be loaded.
            Exception: For any other initialisation error.
        """
        if model_dir is None:
            model_dir = self._default_model_dir()
        try:
            self.short_model = self._load_pickle(os.path.join(model_dir, 'short.pkl'))  # HistGB model
            self.long_model = self._load_pickle(os.path.join(model_dir, 'long.pkl'))  # BiLSTM-CNN model
            # Window lengths must match the ones used during training.
            self.short_seq_length = 33
            self.long_seq_length = 399
            self.feature_extractor = SequenceFeatureExtractor(seq_length=self.short_seq_length)
            self.cnn_processor = CNNInputProcessor(max_length=self.long_seq_length)
            self._detect_model_types()
        except FileNotFoundError as e:
            raise FileNotFoundError(
                f"无法找到模型文件: {str(e)}。请确保模型文件 'short.pkl' 和 'long.pkl' 存在于 {model_dir}"
            ) from e
        except Exception as e:
            raise Exception(f"加载模型出错: {str(e)}") from e

    @staticmethod
    def _default_model_dir():
        """Return the path of the package's bundled 'pretrained' directory.

        Uses ``pkg_resources`` when available (original behaviour) and falls
        back to ``importlib.resources`` where setuptools is not installed.
        """
        try:
            from pkg_resources import resource_filename
        except ImportError:
            from importlib.resources import files
            return str(files('FScanpy') / 'pretrained')
        return resource_filename('FScanpy', 'pretrained')

    def _load_pickle(self, path):
        """Load a model file with joblib.

        joblib.load unpickles the file, which can execute arbitrary code:
        only load model files from a trusted source.

        Raises:
            FileNotFoundError: If the file is missing or fails to load.
        """
        try:
            return joblib.load(path)
        except Exception as e:
            raise FileNotFoundError(f"无法加载模型文件 {path}: {str(e)}") from e

    def _detect_model_types(self):
        """Record whether each model exposes ``predict_proba`` (scikit-learn style)."""
        self.short_is_sklearn = hasattr(self.short_model, 'predict_proba')
        self.long_is_sklearn = hasattr(self.long_model, 'predict_proba')

    def _predict_model(self, model, features, is_sklearn, seq_length):
        """Return the positive-class probability predicted by ``model``.

        Args:
            model: A classifier used through ``predict_proba`` when
                ``is_sklearn`` is true, otherwise through ``predict``.
            features: Feature vector (scikit-learn models) or the raw
                sequence string (deep-learning models).
            is_sklearn: Whether to call ``predict_proba``.
            seq_length: Window length; the long window length routes the
                sequence through the CNN input processor, any other length
                uses the numeric base encoding.

        Returns:
            The predicted probability of the positive class.

        Raises:
            Exception: Wrapping any error raised while predicting.
        """
        try:
            if is_sklearn:
                # predict_proba expects a 2-D (n_samples, n_features) input.
                if isinstance(features, np.ndarray) and features.ndim > 1:
                    features = features.flatten()
                return model.predict_proba(np.array([features]))[0][1]

            if seq_length == self.long_seq_length:
                model_input = self.cnn_processor.prepare_sequence(features)
            else:
                encoded = [self._BASE_TO_NUM.get(base, 0) for base in features.upper()]
                model_input = np.array(encoded).reshape(1, len(encoded), 1)

            # verbose=0 suppresses progress output on Keras-style APIs; retry
            # without it for predict() signatures that do not accept it.
            try:
                pred = model.predict(model_input, verbose=0)
            except TypeError:
                pred = model.predict(model_input)

            if isinstance(pred, list):
                pred = pred[0]
            # Multi-column output: column 1 is the positive class.
            if hasattr(pred, 'shape') and len(pred.shape) > 1 and pred.shape[1] > 1:
                return pred[0][1]
            return pred[0][0] if hasattr(pred[0], '__getitem__') else pred[0]
        except Exception as e:
            raise Exception(f"模型预测失败: {str(e)}") from e

    @staticmethod
    def _format_weights(short_weight, long_weight):
        """Render the ensemble weights as the 'Ensemble_Weights' label."""
        return f'Short:{short_weight:.1f}, Long:{long_weight:.1f}'

    def predict_single_position(self, fs_period, full_seq, short_threshold=0.1, ensemble_weight=0.4):
        """Predict the PRF probability at a single position.

        Args:
            fs_period: Short-window sequence for the short model; input longer
                than the short window is trimmed by the feature extractor.
            full_seq: Long-window sequence for the long model.
            short_threshold: When the short-model probability is below this
                value the long model is skipped and both the long and the
                ensemble probability are reported as 0.0.
            ensemble_weight: Weight of the short model in the ensemble; the
                long model receives ``1 - ensemble_weight``.

        Returns:
            dict with 'Short_Probability', 'Long_Probability',
            'Ensemble_Probability' and 'Ensemble_Weights'.

        Raises:
            Exception: Wrapping invalid arguments or unexpected errors. A
                failure inside either model is printed and scored as 0.0.
        """
        try:
            if not (0.0 <= ensemble_weight <= 1.0):
                raise ValueError("ensemble_weight 必须在 0.0 到 1.0 之间")
            long_weight = 1.0 - ensemble_weight
            weights_label = self._format_weights(ensemble_weight, long_weight)

            if len(fs_period) > self.short_seq_length:
                fs_period = self.feature_extractor.trim_sequence(fs_period, self.short_seq_length)

            # Short model (HistGB): feature vector for sklearn, raw sequence otherwise.
            try:
                if self.short_is_sklearn:
                    short_input = self.feature_extractor.extract_features(fs_period)
                else:
                    short_input = fs_period
                short_prob = self._predict_model(self.short_model, short_input,
                                                 self.short_is_sklearn, self.short_seq_length)
            except Exception as e:
                print(f"Short模型预测时出错: {str(e)}")
                short_prob = 0.0

            # Cheap short-model gate: skip the long model entirely below threshold.
            if short_prob < short_threshold:
                return {
                    'Short_Probability': short_prob,
                    'Long_Probability': 0.0,
                    'Ensemble_Probability': 0.0,
                    'Ensemble_Weights': weights_label,
                }

            # Long model (BiLSTM-CNN).
            try:
                if self.long_is_sklearn:
                    long_input = self.feature_extractor.extract_features(full_seq)
                else:
                    long_input = full_seq
                long_prob = self._predict_model(self.long_model, long_input,
                                                self.long_is_sklearn, self.long_seq_length)
            except Exception as e:
                print(f"Long模型预测时出错: {str(e)}")
                long_prob = 0.0

            return {
                'Short_Probability': short_prob,
                'Long_Probability': long_prob,
                'Ensemble_Probability': ensemble_weight * short_prob + long_weight * long_prob,
                'Ensemble_Weights': weights_label,
            }
        except Exception as e:
            raise Exception(f"预测过程出错: {str(e)}") from e

    def predict_sequence(self, sequence, window_size=3, short_threshold=0.1, ensemble_weight=0.4):
        """Predict PRF sites along a whole sequence with a sliding window.

        Positions ``0, window_size, 2*window_size, ...`` are scored, up to the
        last position that still leaves a full codon. Positions for which
        ``extract_window_sequences`` returns no window are skipped.

        Args:
            sequence: Input DNA sequence (converted to an upper-case string).
            window_size: Step between scored positions (>= 1).
            short_threshold: Short-model gate, see ``predict_single_position``.
            ensemble_weight: Weight of the short model in the ensemble.

        Returns:
            pd.DataFrame with one row per scored position. It holds the
            prediction columns plus 'Position', 'Codon', 'Short_Sequence' and
            'Long_Sequence'; it is empty (no columns) if nothing was scored.

        Raises:
            ValueError: For invalid arguments.
            Exception: Wrapping errors raised during prediction.
        """
        if window_size < 1:
            raise ValueError("窗口大小必须大于等于1")
        if short_threshold < 0:
            raise ValueError("short模型阈值必须大于等于0")
        if not (0.0 <= ensemble_weight <= 1.0):
            raise ValueError("ensemble_weight 必须在 0.0 到 1.0 之间")

        try:
            sequence = str(sequence).upper()
            results = []
            for pos in range(0, len(sequence) - 2, window_size):
                fs_period, full_seq = extract_window_sequences(sequence, pos)
                if fs_period is None or full_seq is None:
                    continue
                pred = self.predict_single_position(fs_period, full_seq, short_threshold, ensemble_weight)
                pred.update({
                    'Position': pos,
                    'Codon': sequence[pos:pos + 3],
                    'Short_Sequence': fs_period,
                    'Long_Sequence': full_seq,
                })
                results.append(pred)
            return pd.DataFrame(results)
        except Exception as e:
            raise Exception(f"序列预测过程出错: {str(e)}") from e

    def plot_sequence_prediction(self, sequence, window_size=3, short_threshold=0.65,
                                 long_threshold=0.8, ensemble_weight=0.4, title=None, save_path=None,
                                 figsize=(12, 8), dpi=300):
        """Plot sliding-window PRF prediction results.

        Args:
            sequence: Input DNA sequence.
            window_size: Sliding window step (default: 3).
            short_threshold: Short model (HistGB) display threshold (default: 0.65).
            long_threshold: Long model (BiLSTM-CNN) display threshold (default: 0.8).
            ensemble_weight: Weight of the short model in the ensemble (default: 0.4).
            title: Plot title (optional).
            save_path: Where to save the figure (str, path-like or file-like;
                optional). A '.png' path also gets a '.pdf' copy.
            figsize: Figure size (default: (12, 8)).
            dpi: Resolution used when saving (default: 300).

        Returns:
            tuple: (pd.DataFrame, matplotlib.figure.Figure) with the prediction
            results and the figure object.

        Raises:
            Exception: Wrapping any error; a partially built figure is closed.
        """
        fig = None
        try:
            if not (0.0 <= ensemble_weight <= 1.0):
                raise ValueError("ensemble_weight must be between 0.0 and 1.0")
            long_weight = 1.0 - ensemble_weight

            # Predict with a permissive short-model gate (0.1, or the caller's
            # threshold if lower) so the long model is evaluated wherever the
            # display thresholds below could keep a position.
            prefilter = max(0.0, min(0.1, short_threshold))
            results_df = self.predict_sequence(sequence, window_size=window_size,
                                               short_threshold=prefilter,
                                               ensemble_weight=ensemble_weight)
            if results_df.empty:
                raise ValueError("Prediction results are empty, please check input sequence")

            seq_length = len(sequence)
            # FS-site marks are ~1% of the sequence length wide (min 3);
            # probability marks are a third of that (min 1).
            desired_visual_width = max(3, seq_length // 100)
            prob_width = max(1, desired_visual_width // 3)

            # Positions passing both model thresholds.
            passes = ((results_df['Short_Probability'] >= short_threshold) &
                      (results_df['Long_Probability'] >= long_threshold))
            kept = results_df[passes]

            fig = plt.figure(figsize=figsize)
            if title:
                fig.suptitle(title, y=0.95, fontsize=10)
            else:
                fig.suptitle(f'PRF Prediction Results (Weights {ensemble_weight:.1f}:{long_weight:.1f})',
                             y=0.95, fontsize=10)
            # Two thin heatmap strips above the main bar chart.
            gs = fig.add_gridspec(3, 1, height_ratios=[0.1, 0.1, 1], hspace=0.2)

            # Strip 1: there is no ground-truth FS information in sliding-window
            # prediction, so high-confidence predictions (ensemble >= 0.8) are shown.
            ax0 = fig.add_subplot(gs[0])
            fs_data = np.zeros((1, seq_length))
            half_width = desired_visual_width // 2
            for pos in kept.loc[kept['Ensemble_Probability'] >= 0.8, 'Position']:
                pos = int(pos)
                fs_data[0, max(0, pos - half_width):min(seq_length, pos + half_width + 1)] = 1
            ax0.imshow(fs_data, cmap='Reds', aspect='auto', interpolation='nearest')
            ax0.set_xticks([])
            ax0.set_yticks([])
            ax0.set_title('FS site', pad=2, fontsize=8)

            # Strip 2: ensemble probability of every position passing both thresholds.
            ax1 = fig.add_subplot(gs[1])
            prob_data = np.zeros((1, seq_length))
            half_prob = prob_width // 2
            for pos, prob in zip(kept['Position'], kept['Ensemble_Probability']):
                pos = int(pos)
                prob_data[0, max(0, pos - half_prob):min(seq_length, pos + half_prob + 1)] = prob
            ax1.imshow(prob_data, cmap='Reds', aspect='auto', vmin=0, vmax=1, interpolation='nearest')
            ax1.set_xticks([])
            ax1.set_yticks([])
            ax1.set_title('Prediction', pad=2, fontsize=8)

            # Main bar chart: positions failing a threshold are drawn at 0.
            ax2 = fig.add_subplot(gs[2])
            filtered_probs = results_df['Ensemble_Probability'].where(passes, 0)
            ax2.bar(results_df['Position'], filtered_probs,
                    alpha=0.6, color='black', width=1.0)
            step = max(seq_length // 10, 50)
            ax2.set_xticks(np.arange(0, seq_length, step))
            ax2.tick_params(axis='x', rotation=45)
            ax2.set_xlabel('Position')
            ax2.set_ylabel('Probability')
            ax2.set_ylim(0, 1)
            ax2.grid(True, alpha=0.3)

            # Keep all panels aligned on the same x range.
            for ax in (ax0, ax1, ax2):
                ax.set_xlim(-1, seq_length)

            fig.tight_layout()

            if save_path:
                fig.savefig(save_path, dpi=dpi, bbox_inches='tight')
                # Also write a PDF next to PNG output; only the suffix is swapped.
                if isinstance(save_path, (str, os.PathLike)):
                    path_str = os.fspath(save_path)
                    if path_str.endswith('.png'):
                        fig.savefig(path_str[:-len('.png')] + '.pdf', bbox_inches='tight')
                print(f"Plot saved to: {save_path}")

            return results_df, fig
        except Exception as e:
            if fig is not None:
                plt.close(fig)
            raise Exception(f"Error plotting sequence prediction: {str(e)}") from e

    @staticmethod
    def _as_sequence_list(sequences):
        """Normalise ``predict_regions`` input to an iterable of sequences.

        A DataFrame must either have a single column or contain one of the
        columns 'Long_Sequence', '399bp', 'Sequence' or 'sequence'.
        """
        if isinstance(sequences, pd.DataFrame):
            if sequences.shape[1] == 1:
                return sequences.iloc[:, 0].tolist()
            for column in ('Long_Sequence', '399bp', 'Sequence', 'sequence'):
                if column in sequences.columns:
                    return sequences[column].tolist()
            raise ValueError(
                "DataFrame input must have a single column or a column named "
                "'Long_Sequence', '399bp', 'Sequence' or 'sequence'")
        if isinstance(sequences, pd.Series):
            return sequences.tolist()
        if isinstance(sequences, str):
            return [sequences]
        return sequences

    def predict_regions(self, sequences, short_threshold=0.1, ensemble_weight=0.4):
        """Batch-predict known long-window (399 bp) region sequences.

        For each sequence the centred short window is extracted for the short
        model and the full sequence is given to the long model.

        Args:
            sequences: A sequence string, or a list/Series/DataFrame of them
                (see ``_as_sequence_list`` for accepted DataFrame layouts).
            short_threshold: Short model probability threshold (default: 0.1).
            ensemble_weight: Weight of the short model in the ensemble (default: 0.4).

        Returns:
            DataFrame with one row per input sequence. A sequence whose
            prediction fails is reported and given 0.0 probabilities.

        Raises:
            Exception: Wrapping invalid arguments or input format errors.
        """
        try:
            if not (0.0 <= ensemble_weight <= 1.0):
                raise ValueError("ensemble_weight must be between 0.0 and 1.0")
            long_weight = 1.0 - ensemble_weight

            results = []
            for i, seq399 in enumerate(self._as_sequence_list(sequences)):
                try:
                    seq33 = self._extract_center_sequence(seq399, target_length=self.short_seq_length)
                    pred_result = self.predict_single_position(seq33, seq399, short_threshold, ensemble_weight)
                    pred_result.update({
                        'Short_Sequence': seq33,
                        'Long_Sequence': seq399,
                    })
                except Exception as e:
                    print(f"Error processing sequence {i+1}: {str(e)}")
                    # Must not raise itself: non-string entries (None/NaN) go
                    # through _extract_center_sequence, which stringifies them.
                    if isinstance(seq399, str) and len(seq399) < self.short_seq_length:
                        short_seq = seq399
                    else:
                        short_seq = self._extract_center_sequence(seq399, target_length=self.short_seq_length)
                    pred_result = {
                        'Short_Probability': 0.0,
                        'Long_Probability': 0.0,
                        'Ensemble_Probability': 0.0,
                        'Ensemble_Weights': self._format_weights(ensemble_weight, long_weight),
                        'Short_Sequence': short_seq,
                        'Long_Sequence': seq399,
                    }
                results.append(pred_result)
            return pd.DataFrame(results)
        except Exception as e:
            raise Exception(f"Error in region prediction process: {str(e)}") from e

    def _extract_center_sequence(self, sequence, target_length=33):
        """Return the centred window of length ``target_length``.

        The input is converted to an upper-case string first. Sequences no
        longer than ``target_length`` are returned whole. For an odd
        ``target_length`` the base at ``len(sequence) // 2`` sits in the middle
        of the window.
        """
        sequence = str(sequence).upper()
        if len(sequence) <= target_length:
            return sequence
        # With len(sequence) > target_length, start >= 0 and
        # start + target_length <= len(sequence) always hold; no clamping needed.
        start = len(sequence) // 2 - target_length // 2
        return sequence[start:start + target_length]

    @staticmethod
    def _add_legacy_columns(results_df):
        """Add (in place) the column aliases expected by the deprecated API."""
        aliases = {
            'Ensemble_Probability': ('Voting_Probability', 'Weighted_Probability'),
            'Ensemble_Weights': ('Weight_Info',),
            'Short_Sequence': ('33bp',),
            'Long_Sequence': ('399bp',),
        }
        for source, targets in aliases.items():
            if source in results_df.columns:
                for target in targets:
                    results_df[target] = results_df[source]
        return results_df

    # Backward-compatible methods (deprecated).
    def predict_full(self, sequence, window_size=3, short_threshold=0.1, short_weight=0.4, plot=False):
        """Deprecated: use ``predict_sequence()``.

        Calls ``predict_sequence()`` and adds the legacy column aliases. With
        ``plot=True`` it also returns a figure from
        ``plot_sequence_prediction()``, which uses thresholds 0.65 / 0.8.
        """
        import warnings
        warnings.warn("predict_full() 已废弃,请使用 predict_sequence() 方法", DeprecationWarning, stacklevel=2)
        results_df = self._add_legacy_columns(
            self.predict_sequence(sequence, window_size, short_threshold, short_weight))
        if plot:
            _, fig = self.plot_sequence_prediction(sequence, window_size, 0.65, 0.8, short_weight)
            return results_df, fig
        return results_df

    def predict_region(self, seq, short_threshold=0.1, short_weight=0.4):
        """Deprecated: use ``predict_regions()``.

        Calls ``predict_regions()`` and adds the legacy column aliases.
        """
        import warnings
        warnings.warn("predict_region() 已废弃,请使用 predict_regions() 方法", DeprecationWarning, stacklevel=2)
        return self._add_legacy_columns(self.predict_regions(seq, short_threshold, short_weight))