优化绘图,增加使用介绍的ipynb
This commit is contained in:
parent
3833704c47
commit
089df9c4a6
|
|
@ -19,61 +19,61 @@ def predict_prf(
|
|||
model_dir: str = None
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
PRF位点预测函数
|
||||
PRF site prediction function
|
||||
|
||||
Args:
|
||||
sequence: 单个或多个DNA序列,用于滑动窗口预测
|
||||
data: DataFrame数据,必须包含'Long_Sequence'或'399bp'列,用于区域预测
|
||||
window_size: 滑动窗口大小(默认为3)
|
||||
short_threshold: Short模型(HistGB)概率阈值(默认为0.1)
|
||||
ensemble_weight: Short模型在集成中的权重(默认为0.4,Long权重为0.6)
|
||||
model_dir: 模型文件目录路径(可选)
|
||||
sequence: Single or multiple DNA sequences for sliding window prediction
|
||||
data: DataFrame data, must contain 'Long_Sequence' or '399bp' column for region prediction
|
||||
window_size: Sliding window size (default: 3)
|
||||
short_threshold: Short model (HistGB) probability threshold (default: 0.1)
|
||||
ensemble_weight: Weight of short model in ensemble (default: 0.4, long weight: 0.6)
|
||||
model_dir: Model directory path (optional)
|
||||
|
||||
Returns:
|
||||
pandas.DataFrame: 预测结果,包含以下主要字段:
|
||||
- Short_Probability: Short模型预测概率
|
||||
- Long_Probability: Long模型预测概率
|
||||
- Ensemble_Probability: 集成预测概率(主要结果)
|
||||
- Ensemble_Weights: 权重配置信息
|
||||
pandas.DataFrame: Prediction results containing the following main fields:
|
||||
- Short_Probability: Short model prediction probability
|
||||
- Long_Probability: Long model prediction probability
|
||||
- Ensemble_Probability: Ensemble prediction probability (main result)
|
||||
- Ensemble_Weights: Weight configuration information
|
||||
|
||||
Examples:
|
||||
# 1. 单条序列滑动窗口预测
|
||||
# 1. Single sequence sliding window prediction
|
||||
>>> from FScanpy import predict_prf
|
||||
>>> sequence = "ATGCGTACGT..."
|
||||
>>> results = predict_prf(sequence=sequence)
|
||||
|
||||
# 2. 多条序列滑动窗口预测
|
||||
# 2. Multiple sequences sliding window prediction
|
||||
>>> sequences = ["ATGCGTACGT...", "GCTATAGCAT..."]
|
||||
>>> results = predict_prf(sequence=sequences)
|
||||
|
||||
# 3. 自定义集成权重比例
|
||||
>>> results = predict_prf(sequence=sequence, ensemble_weight=0.3) # 3:7 比例
|
||||
# 3. Custom ensemble weight ratio
|
||||
>>> results = predict_prf(sequence=sequence, ensemble_weight=0.3) # 3:7 ratio
|
||||
|
||||
# 4. DataFrame区域预测
|
||||
# 4. DataFrame region prediction
|
||||
>>> import pandas as pd
|
||||
>>> data = pd.DataFrame({
|
||||
... 'Long_Sequence': ['ATGCGT...', 'GCTATAG...'] # 或使用 '399bp'
|
||||
... 'Long_Sequence': ['ATGCGT...', 'GCTATAG...'] # or use '399bp'
|
||||
... })
|
||||
>>> results = predict_prf(data=data)
|
||||
"""
|
||||
predictor = PRFPredictor(model_dir=model_dir)
|
||||
|
||||
# 验证输入参数
|
||||
# Validate input parameters
|
||||
if sequence is None and data is None:
|
||||
raise ValueError("必须提供sequence或data参数之一")
|
||||
raise ValueError("Must provide either sequence or data parameter")
|
||||
if sequence is not None and data is not None:
|
||||
raise ValueError("sequence和data参数不能同时提供")
|
||||
raise ValueError("Cannot provide both sequence and data parameters")
|
||||
if not (0.0 <= ensemble_weight <= 1.0):
|
||||
raise ValueError("ensemble_weight 必须在 0.0 到 1.0 之间")
|
||||
raise ValueError("ensemble_weight must be between 0.0 and 1.0")
|
||||
|
||||
# 滑动窗口预测模式
|
||||
# Sliding window prediction mode
|
||||
if sequence is not None:
|
||||
if isinstance(sequence, str):
|
||||
# 单条序列预测
|
||||
# Single sequence prediction
|
||||
return predictor.predict_sequence(
|
||||
sequence, window_size, short_threshold, ensemble_weight)
|
||||
elif isinstance(sequence, (list, tuple)):
|
||||
# 多条序列预测
|
||||
# Multiple sequences prediction
|
||||
results = []
|
||||
for i, seq in enumerate(sequence, 1):
|
||||
try:
|
||||
|
|
@ -82,29 +82,29 @@ def predict_prf(
|
|||
result['Sequence_ID'] = f'seq_{i}'
|
||||
results.append(result)
|
||||
except Exception as e:
|
||||
print(f"警告:序列 {i} 预测失败 - {str(e)}")
|
||||
print(f"Warning: Sequence {i} prediction failed - {str(e)}")
|
||||
return pd.concat(results, ignore_index=True) if results else pd.DataFrame()
|
||||
|
||||
# 区域化预测模式
|
||||
# Region prediction mode
|
||||
else:
|
||||
if not isinstance(data, pd.DataFrame):
|
||||
raise ValueError("data参数必须是pandas DataFrame类型")
|
||||
raise ValueError("data parameter must be pandas DataFrame type")
|
||||
|
||||
# 检查列名(支持新旧两种命名)
|
||||
# Check column names (support both new and old naming conventions)
|
||||
seq_column = None
|
||||
if 'Long_Sequence' in data.columns:
|
||||
seq_column = 'Long_Sequence'
|
||||
elif '399bp' in data.columns:
|
||||
seq_column = '399bp'
|
||||
else:
|
||||
raise ValueError("DataFrame必须包含'Long_Sequence'或'399bp'列")
|
||||
raise ValueError("DataFrame must contain 'Long_Sequence' or '399bp' column")
|
||||
|
||||
# 调用区域预测函数
|
||||
# Call region prediction function
|
||||
try:
|
||||
results = predictor.predict_regions(
|
||||
data[seq_column], short_threshold, ensemble_weight)
|
||||
|
||||
# 添加原始数据的其他列
|
||||
# Add other columns from original data
|
||||
for col in data.columns:
|
||||
if col not in ['Long_Sequence', '399bp', 'Short_Sequence', '33bp']:
|
||||
results[col] = data[col].values
|
||||
|
|
@ -112,8 +112,8 @@ def predict_prf(
|
|||
return results
|
||||
|
||||
except Exception as e:
|
||||
print(f"警告:区域预测失败 - {str(e)}")
|
||||
# 创建空结果
|
||||
print(f"Warning: Region prediction failed - {str(e)}")
|
||||
# Create empty results
|
||||
long_weight = 1.0 - ensemble_weight
|
||||
results = pd.DataFrame({
|
||||
'Short_Probability': [0.0] * len(data),
|
||||
|
|
@ -122,7 +122,7 @@ def predict_prf(
|
|||
'Ensemble_Weights': [f'Short:{ensemble_weight:.1f}, Long:{long_weight:.1f}'] * len(data)
|
||||
})
|
||||
|
||||
# 添加原始数据列
|
||||
# Add original data columns
|
||||
for col in data.columns:
|
||||
results[col] = data[col].values
|
||||
|
||||
|
|
@ -136,59 +136,59 @@ def plot_prf_prediction(
|
|||
ensemble_weight: float = 0.4,
|
||||
title: str = None,
|
||||
save_path: str = None,
|
||||
figsize: tuple = (12, 6),
|
||||
figsize: tuple = (12, 8),
|
||||
dpi: int = 300,
|
||||
model_dir: str = None
|
||||
) -> tuple:
|
||||
"""
|
||||
绘制序列PRF预测结果的移码概率图
|
||||
Plot PRF prediction results for sequence frameshifting probability
|
||||
|
||||
Args:
|
||||
sequence: 输入DNA序列
|
||||
window_size: 滑动窗口大小(默认为3)
|
||||
short_threshold: Short模型(HistGB)过滤阈值(默认为0.65)
|
||||
long_threshold: Long模型(BiLSTM-CNN)过滤阈值(默认为0.8)
|
||||
ensemble_weight: Short模型在集成中的权重(默认为0.4,Long权重为0.6)
|
||||
title: 图片标题(可选)
|
||||
save_path: 保存路径(可选,如果提供则保存图片)
|
||||
figsize: 图片尺寸(默认为(12, 6))
|
||||
dpi: 图片分辨率(默认为300)
|
||||
model_dir: 模型文件目录路径(可选)
|
||||
sequence: Input DNA sequence
|
||||
window_size: Sliding window size (default: 3)
|
||||
short_threshold: Short model (HistGB) filtering threshold (default: 0.65)
|
||||
long_threshold: Long model (BiLSTM-CNN) filtering threshold (default: 0.8)
|
||||
ensemble_weight: Weight of short model in ensemble (default: 0.4, long weight: 0.6)
|
||||
title: Plot title (optional)
|
||||
save_path: Save path (optional, saves plot if provided)
|
||||
figsize: Figure size (default: (12, 8))
|
||||
dpi: Figure resolution (default: 300)
|
||||
model_dir: Model directory path (optional)
|
||||
|
||||
Returns:
|
||||
tuple: (pd.DataFrame, matplotlib.figure.Figure) 预测结果和图形对象
|
||||
tuple: (pd.DataFrame, matplotlib.figure.Figure) prediction results and figure object
|
||||
|
||||
Examples:
|
||||
# 1. 简单绘图
|
||||
# 1. Simple plotting
|
||||
>>> from FScanpy import plot_prf_prediction
|
||||
>>> sequence = "ATGCGTACGT..."
|
||||
>>> results, fig = plot_prf_prediction(sequence)
|
||||
>>> plt.show()
|
||||
|
||||
# 2. 自定义阈值和集成权重
|
||||
# 2. Custom thresholds and ensemble weights
|
||||
>>> results, fig = plot_prf_prediction(
|
||||
... sequence,
|
||||
... short_threshold=0.7,
|
||||
... long_threshold=0.85,
|
||||
... ensemble_weight=0.3, # 3:7 权重比例
|
||||
... title="自定义权重的预测结果",
|
||||
... ensemble_weight=0.3, # 3:7 weight ratio
|
||||
... title="Custom Weight Prediction Results",
|
||||
... save_path="prediction_result.png"
|
||||
... )
|
||||
|
||||
# 3. 等权重组合
|
||||
# 3. Equal weight combination
|
||||
>>> results, fig = plot_prf_prediction(
|
||||
... sequence,
|
||||
... ensemble_weight=0.5 # 5:5 等权重
|
||||
... ensemble_weight=0.5 # 5:5 equal weights
|
||||
... )
|
||||
|
||||
# 4. Long模型主导
|
||||
# 4. Long model dominated
|
||||
>>> results, fig = plot_prf_prediction(
|
||||
... sequence,
|
||||
... ensemble_weight=0.2 # 2:8 权重,Long模型主导
|
||||
... ensemble_weight=0.2 # 2:8 weights, long model dominated
|
||||
... )
|
||||
"""
|
||||
if not (0.0 <= ensemble_weight <= 1.0):
|
||||
raise ValueError("ensemble_weight 必须在 0.0 到 1.0 之间")
|
||||
raise ValueError("ensemble_weight must be between 0.0 and 1.0")
|
||||
|
||||
predictor = PRFPredictor(model_dir=model_dir)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,487 +1,499 @@
|
|||
import os
|
||||
import pickle
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from tensorflow.keras.models import load_model
|
||||
from .features.sequence import SequenceFeatureExtractor
|
||||
from .features.cnn_input import CNNInputProcessor
|
||||
from .utils import extract_window_sequences
|
||||
import matplotlib.pyplot as plt
|
||||
import joblib
|
||||
|
||||
|
||||
class PRFPredictor:
|
||||
|
||||
def __init__(self, model_dir=None):
|
||||
"""
|
||||
初始化PRF预测器
|
||||
|
||||
Args:
|
||||
model_dir: 模型目录路径(可选)
|
||||
"""
|
||||
if model_dir is None:
|
||||
from pkg_resources import resource_filename
|
||||
model_dir = resource_filename('FScanpy', 'pretrained')
|
||||
|
||||
try:
|
||||
# 加载模型 - 使用新的命名约定
|
||||
self.short_model = self._load_pickle(os.path.join(model_dir, 'short.pkl')) # HistGB模型
|
||||
self.long_model = self._load_pickle(os.path.join(model_dir, 'long.pkl')) # BiLSTM-CNN模型
|
||||
|
||||
# 初始化特征提取器和CNN处理器,使用与训练时相同的序列长度
|
||||
self.short_seq_length = 33 # HistGB使用的序列长度
|
||||
self.long_seq_length = 399 # BiLSTM-CNN使用的序列长度
|
||||
|
||||
# 初始化特征提取器和CNN输入处理器
|
||||
self.feature_extractor = SequenceFeatureExtractor(seq_length=self.short_seq_length)
|
||||
self.cnn_processor = CNNInputProcessor(max_length=self.long_seq_length)
|
||||
|
||||
# 检测模型类型以优化预测性能
|
||||
self._detect_model_types()
|
||||
|
||||
except FileNotFoundError as e:
|
||||
raise FileNotFoundError(f"无法找到模型文件: {str(e)}。请确保模型文件 'short.pkl' 和 'long.pkl' 存在于 {model_dir}")
|
||||
except Exception as e:
|
||||
raise Exception(f"加载模型出错: {str(e)}")
|
||||
|
||||
def _load_pickle(self, path):
|
||||
"""安全加载pickle文件"""
|
||||
try:
|
||||
return joblib.load(path)
|
||||
except Exception as e:
|
||||
raise FileNotFoundError(f"无法加载模型文件 {path}: {str(e)}")
|
||||
|
||||
def _detect_model_types(self):
|
||||
"""检测模型类型以优化预测性能"""
|
||||
self.short_is_sklearn = hasattr(self.short_model, 'predict_proba')
|
||||
self.long_is_sklearn = hasattr(self.long_model, 'predict_proba')
|
||||
|
||||
def _predict_model(self, model, features, is_sklearn, seq_length):
|
||||
"""统一的模型预测方法"""
|
||||
try:
|
||||
if is_sklearn:
|
||||
# sklearn模型使用特征向量
|
||||
if isinstance(features, np.ndarray) and features.ndim > 1:
|
||||
features = features.flatten()
|
||||
features_2d = np.array([features])
|
||||
pred = model.predict_proba(features_2d)
|
||||
return pred[0][1]
|
||||
else:
|
||||
# 深度学习模型
|
||||
if seq_length == self.long_seq_length:
|
||||
# 对于长序列,使用CNN处理器
|
||||
model_input = self.cnn_processor.prepare_sequence(features)
|
||||
else:
|
||||
# 对于短序列,转换为数值编码
|
||||
base_to_num = {'A': 1, 'T': 2, 'G': 3, 'C': 4, 'N': 0}
|
||||
seq_numeric = [base_to_num.get(base, 0) for base in features.upper()]
|
||||
model_input = np.array(seq_numeric).reshape(1, len(seq_numeric), 1)
|
||||
|
||||
# 统一的预测调用
|
||||
try:
|
||||
pred = model.predict(model_input, verbose=0)
|
||||
except TypeError:
|
||||
pred = model.predict(model_input)
|
||||
|
||||
# 处理预测结果
|
||||
if isinstance(pred, list):
|
||||
pred = pred[0]
|
||||
|
||||
if hasattr(pred, 'shape') and len(pred.shape) > 1 and pred.shape[1] > 1:
|
||||
return pred[0][1]
|
||||
else:
|
||||
return pred[0][0] if hasattr(pred[0], '__getitem__') else pred[0]
|
||||
|
||||
except Exception as e:
|
||||
raise Exception(f"模型预测失败: {str(e)}")
|
||||
|
||||
def predict_single_position(self, fs_period, full_seq, short_threshold=0.1, ensemble_weight=0.4):
|
||||
'''
|
||||
预测单个位置的PRF状态
|
||||
|
||||
Args:
|
||||
fs_period: 33bp序列 (short模型使用)
|
||||
full_seq: 完整序列 (long模型使用)
|
||||
short_threshold: short模型的概率阈值 (默认为0.1)
|
||||
ensemble_weight: short模型在集成中的权重 (默认为0.4,long权重为0.6)
|
||||
Returns:
|
||||
dict: 包含预测概率的字典
|
||||
'''
|
||||
try:
|
||||
# 验证权重参数
|
||||
if not (0.0 <= ensemble_weight <= 1.0):
|
||||
raise ValueError("ensemble_weight 必须在 0.0 到 1.0 之间")
|
||||
|
||||
long_weight = 1.0 - ensemble_weight
|
||||
|
||||
# 处理序列长度
|
||||
if len(fs_period) > self.short_seq_length:
|
||||
fs_period = self.feature_extractor.trim_sequence(fs_period, self.short_seq_length)
|
||||
|
||||
# Short模型预测 (HistGB)
|
||||
try:
|
||||
if self.short_is_sklearn:
|
||||
short_features = self.feature_extractor.extract_features(fs_period)
|
||||
short_prob = self._predict_model(self.short_model, short_features, True, self.short_seq_length)
|
||||
else:
|
||||
short_prob = self._predict_model(self.short_model, fs_period, False, self.short_seq_length)
|
||||
except Exception as e:
|
||||
print(f"Short模型预测时出错: {str(e)}")
|
||||
short_prob = 0.0
|
||||
|
||||
# 如果short概率低于阈值,则跳过long模型
|
||||
if short_prob < short_threshold:
|
||||
return {
|
||||
'Short_Probability': short_prob,
|
||||
'Long_Probability': 0.0,
|
||||
'Ensemble_Probability': 0.0,
|
||||
'Ensemble_Weights': f'Short:{ensemble_weight:.1f}, Long:{long_weight:.1f}'
|
||||
}
|
||||
|
||||
# Long模型预测 (BiLSTM-CNN)
|
||||
try:
|
||||
if self.long_is_sklearn:
|
||||
long_features = self.feature_extractor.extract_features(full_seq)
|
||||
long_prob = self._predict_model(self.long_model, long_features, True, self.long_seq_length)
|
||||
else:
|
||||
long_prob = self._predict_model(self.long_model, full_seq, False, self.long_seq_length)
|
||||
except Exception as e:
|
||||
print(f"Long模型预测时出错: {str(e)}")
|
||||
long_prob = 0.0
|
||||
|
||||
# 计算集成概率
|
||||
try:
|
||||
ensemble_prob = ensemble_weight * short_prob + long_weight * long_prob
|
||||
except Exception as e:
|
||||
print(f"计算集成概率时出错: {str(e)}")
|
||||
ensemble_prob = (short_prob + long_prob) / 2
|
||||
|
||||
return {
|
||||
'Short_Probability': short_prob,
|
||||
'Long_Probability': long_prob,
|
||||
'Ensemble_Probability': ensemble_prob,
|
||||
'Ensemble_Weights': f'Short:{ensemble_weight:.1f}, Long:{long_weight:.1f}'
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
raise Exception(f"预测过程出错: {str(e)}")
|
||||
|
||||
def predict_sequence(self, sequence, window_size=3, short_threshold=0.1, ensemble_weight=0.4):
|
||||
"""
|
||||
预测完整序列中的PRF位点(滑动窗口方法)
|
||||
|
||||
Args:
|
||||
sequence: 输入DNA序列
|
||||
window_size: 滑动窗口大小 (默认为3)
|
||||
short_threshold: short模型概率阈值 (默认为0.1)
|
||||
ensemble_weight: short模型在集成中的权重 (默认为0.4)
|
||||
|
||||
Returns:
|
||||
pd.DataFrame: 包含预测结果的DataFrame
|
||||
"""
|
||||
if window_size < 1:
|
||||
raise ValueError("窗口大小必须大于等于1")
|
||||
if short_threshold < 0:
|
||||
raise ValueError("short模型阈值必须大于等于0")
|
||||
if not (0.0 <= ensemble_weight <= 1.0):
|
||||
raise ValueError("ensemble_weight 必须在 0.0 到 1.0 之间")
|
||||
|
||||
results = []
|
||||
long_weight = 1.0 - ensemble_weight
|
||||
|
||||
try:
|
||||
# 确保序列为字符串并转换为大写
|
||||
sequence = str(sequence).upper()
|
||||
|
||||
# 滑动窗口预测
|
||||
for pos in range(0, len(sequence) - 2, window_size):
|
||||
# 提取窗口序列
|
||||
fs_period, full_seq = extract_window_sequences(sequence, pos)
|
||||
|
||||
if fs_period is None or full_seq is None:
|
||||
continue
|
||||
|
||||
# 预测并记录结果
|
||||
pred = self.predict_single_position(fs_period, full_seq, short_threshold, ensemble_weight)
|
||||
pred.update({
|
||||
'Position': pos,
|
||||
'Codon': sequence[pos:pos+3],
|
||||
'Short_Sequence': fs_period, # 更清晰的命名
|
||||
'Long_Sequence': full_seq # 更清晰的命名
|
||||
})
|
||||
results.append(pred)
|
||||
|
||||
# 创建结果DataFrame
|
||||
results_df = pd.DataFrame(results)
|
||||
|
||||
return results_df
|
||||
|
||||
except Exception as e:
|
||||
raise Exception(f"序列预测过程出错: {str(e)}")
|
||||
|
||||
def plot_sequence_prediction(self, sequence, window_size=3, short_threshold=0.65,
|
||||
long_threshold=0.8, ensemble_weight=0.4, title=None, save_path=None,
|
||||
figsize=(12, 6), dpi=300):
|
||||
"""
|
||||
绘制序列预测结果的移码概率图
|
||||
|
||||
Args:
|
||||
sequence: 输入DNA序列
|
||||
window_size: 滑动窗口大小 (默认为3)
|
||||
short_threshold: Short模型(HistGB)过滤阈值 (默认为0.65)
|
||||
long_threshold: Long模型(BiLSTM-CNN)过滤阈值 (默认为0.8)
|
||||
ensemble_weight: Short模型在集成中的权重 (默认为0.4)
|
||||
title: 图片标题 (可选)
|
||||
save_path: 保存路径 (可选,如果提供则保存图片)
|
||||
figsize: 图片尺寸 (默认为(12, 6))
|
||||
dpi: 图片分辨率 (默认为300)
|
||||
|
||||
Returns:
|
||||
tuple: (pd.DataFrame, matplotlib.figure.Figure) 预测结果和图形对象
|
||||
"""
|
||||
try:
|
||||
# 验证权重参数
|
||||
if not (0.0 <= ensemble_weight <= 1.0):
|
||||
raise ValueError("ensemble_weight 必须在 0.0 到 1.0 之间")
|
||||
|
||||
long_weight = 1.0 - ensemble_weight
|
||||
|
||||
# 获取预测结果 - 使用新的方法名
|
||||
results_df = self.predict_sequence(sequence, window_size=window_size,
|
||||
short_threshold=0.1, ensemble_weight=ensemble_weight)
|
||||
|
||||
if results_df.empty:
|
||||
raise ValueError("预测结果为空,请检查输入序列")
|
||||
|
||||
# 获取序列长度
|
||||
seq_length = len(sequence)
|
||||
|
||||
# 计算显示宽度
|
||||
prob_width = max(1, seq_length // 300) # 概率标记的宽度
|
||||
|
||||
# 创建图形,包含两个子图
|
||||
fig = plt.figure(figsize=figsize)
|
||||
gs = fig.add_gridspec(2, 1, height_ratios=[0.15, 1], hspace=0.3)
|
||||
|
||||
# 设置标题
|
||||
if title:
|
||||
fig.suptitle(title, y=0.95, fontsize=12)
|
||||
else:
|
||||
fig.suptitle(f'序列移码概率预测结果 (权重 {ensemble_weight:.1f}:{long_weight:.1f})', y=0.95, fontsize=12)
|
||||
|
||||
# 预测概率热图
|
||||
ax0 = fig.add_subplot(gs[0])
|
||||
prob_data = np.zeros((1, seq_length))
|
||||
|
||||
# 应用双重阈值过滤
|
||||
for _, row in results_df.iterrows():
|
||||
pos = int(row['Position'])
|
||||
if (row['Short_Probability'] >= short_threshold and
|
||||
row['Long_Probability'] >= long_threshold):
|
||||
# 为每个满足阈值的位置设置概率值
|
||||
start = max(0, pos - prob_width//2)
|
||||
end = min(seq_length, pos + prob_width//2 + 1)
|
||||
prob_data[0, start:end] = row['Ensemble_Probability']
|
||||
|
||||
im = ax0.imshow(prob_data, cmap='Reds', aspect='auto', vmin=0, vmax=1,
|
||||
interpolation='nearest')
|
||||
ax0.set_xticks([])
|
||||
ax0.set_yticks([])
|
||||
ax0.set_title(f'预测概率热图 (Short≥{short_threshold}, Long≥{long_threshold})',
|
||||
pad=5, fontsize=10)
|
||||
|
||||
# 主图(条形图)
|
||||
ax1 = fig.add_subplot(gs[1])
|
||||
|
||||
# 应用过滤阈值
|
||||
filtered_probs = results_df['Ensemble_Probability'].copy()
|
||||
mask = ((results_df['Short_Probability'] < short_threshold) |
|
||||
(results_df['Long_Probability'] < long_threshold))
|
||||
filtered_probs[mask] = 0
|
||||
|
||||
# 绘制条形图
|
||||
bars = ax1.bar(results_df['Position'], filtered_probs,
|
||||
alpha=0.7, color='darkred', width=max(1, window_size))
|
||||
|
||||
# 设置x轴刻度
|
||||
step = max(seq_length // 10, 50)
|
||||
x_ticks = np.arange(0, seq_length, step)
|
||||
ax1.set_xticks(x_ticks)
|
||||
ax1.tick_params(axis='x', rotation=45)
|
||||
|
||||
# 设置标签和标题
|
||||
ax1.set_xlabel('序列位置 (bp)', fontsize=10)
|
||||
ax1.set_ylabel('移码概率', fontsize=10)
|
||||
ax1.set_title(f'移码概率分布 (集成权重 {ensemble_weight:.1f}:{long_weight:.1f})', fontsize=11)
|
||||
|
||||
# 设置y轴范围
|
||||
ax1.set_ylim(0, 1)
|
||||
|
||||
# 添加网格
|
||||
ax1.grid(True, alpha=0.3)
|
||||
|
||||
# 添加阈值和权重说明
|
||||
info_text = (f'过滤阈值: Short≥{short_threshold}, Long≥{long_threshold}\n'
|
||||
f'集成权重: Short:{ensemble_weight:.1f}, Long:{long_weight:.1f}')
|
||||
ax1.text(0.02, 0.95, info_text, transform=ax1.transAxes,
|
||||
fontsize=9, verticalalignment='top',
|
||||
bbox=dict(boxstyle='round', facecolor='lightgray', alpha=0.8))
|
||||
|
||||
# 确保所有子图的x轴范围一致
|
||||
for ax in [ax0, ax1]:
|
||||
ax.set_xlim(-1, seq_length)
|
||||
|
||||
# 调整布局
|
||||
plt.tight_layout()
|
||||
|
||||
# 如果提供了保存路径,则保存图片
|
||||
if save_path:
|
||||
plt.savefig(save_path, dpi=dpi, bbox_inches='tight')
|
||||
# 同时保存PDF版本
|
||||
if save_path.endswith('.png'):
|
||||
pdf_path = save_path.replace('.png', '.pdf')
|
||||
plt.savefig(pdf_path, bbox_inches='tight')
|
||||
print(f"图片已保存至: {save_path}")
|
||||
|
||||
return results_df, fig
|
||||
|
||||
except Exception as e:
|
||||
raise Exception(f"绘制序列预测图时出错: {str(e)}")
|
||||
|
||||
def predict_regions(self, sequences, short_threshold=0.1, ensemble_weight=0.4):
|
||||
'''
|
||||
预测区域序列(批量预测已知的399bp序列)
|
||||
|
||||
Args:
|
||||
sequences: 399bp序列或包含399bp序列的DataFrame/Series/list
|
||||
short_threshold: short模型概率阈值 (默认为0.1)
|
||||
ensemble_weight: short模型在集成中的权重 (默认为0.4)
|
||||
|
||||
Returns:
|
||||
DataFrame: 包含所有序列预测概率的DataFrame
|
||||
'''
|
||||
try:
|
||||
# 验证权重参数
|
||||
if not (0.0 <= ensemble_weight <= 1.0):
|
||||
raise ValueError("ensemble_weight 必须在 0.0 到 1.0 之间")
|
||||
|
||||
# 统一输入格式
|
||||
if isinstance(sequences, (pd.DataFrame, pd.Series)):
|
||||
sequences = sequences.tolist()
|
||||
elif isinstance(sequences, str):
|
||||
sequences = [sequences]
|
||||
|
||||
results = []
|
||||
for i, seq399 in enumerate(sequences):
|
||||
try:
|
||||
# 从399bp序列中截取中心的33bp (short模型使用)
|
||||
seq33 = self._extract_center_sequence(seq399, target_length=self.short_seq_length)
|
||||
|
||||
# 使用统一的预测方法
|
||||
pred_result = self.predict_single_position(seq33, seq399, short_threshold, ensemble_weight)
|
||||
pred_result.update({
|
||||
'Short_Sequence': seq33,
|
||||
'Long_Sequence': seq399
|
||||
})
|
||||
|
||||
results.append(pred_result)
|
||||
|
||||
except Exception as e:
|
||||
print(f"处理第 {i+1} 个序列时出错: {str(e)}")
|
||||
long_weight = 1.0 - ensemble_weight
|
||||
results.append({
|
||||
'Short_Probability': 0.0,
|
||||
'Long_Probability': 0.0,
|
||||
'Ensemble_Probability': 0.0,
|
||||
'Ensemble_Weights': f'Short:{ensemble_weight:.1f}, Long:{long_weight:.1f}',
|
||||
'Short_Sequence': self._extract_center_sequence(seq399, target_length=self.short_seq_length) if len(seq399) >= self.short_seq_length else seq399,
|
||||
'Long_Sequence': seq399
|
||||
})
|
||||
|
||||
return pd.DataFrame(results)
|
||||
|
||||
except Exception as e:
|
||||
raise Exception(f"区域预测过程出错: {str(e)}")
|
||||
|
||||
def _extract_center_sequence(self, sequence, target_length=33):
|
||||
"""从序列中心位置提取指定长度的子序列"""
|
||||
# 确保序列为字符串
|
||||
sequence = str(sequence).upper()
|
||||
|
||||
# 如果序列长度小于目标长度,返回原序列
|
||||
if len(sequence) <= target_length:
|
||||
return sequence
|
||||
|
||||
# 计算中心位置
|
||||
center = len(sequence) // 2
|
||||
half_target = target_length // 2
|
||||
|
||||
# 提取中心序列
|
||||
start = center - half_target
|
||||
end = start + target_length
|
||||
|
||||
# 边界检查
|
||||
if start < 0:
|
||||
start = 0
|
||||
end = target_length
|
||||
elif end > len(sequence):
|
||||
end = len(sequence)
|
||||
start = end - target_length
|
||||
|
||||
return sequence[start:end]
|
||||
|
||||
# 兼容性方法(向后兼容,但标记为废弃)
|
||||
def predict_full(self, sequence, window_size=3, short_threshold=0.1, short_weight=0.4, plot=False):
|
||||
"""
|
||||
⚠️ 已废弃:请使用 predict_sequence() 方法
|
||||
|
||||
向后兼容的方法,内部调用新的 predict_sequence()
|
||||
"""
|
||||
import warnings
|
||||
warnings.warn("predict_full() 已废弃,请使用 predict_sequence() 方法", DeprecationWarning, stacklevel=2)
|
||||
|
||||
# 调用新方法并添加兼容性字段
|
||||
results_df = self.predict_sequence(sequence, window_size, short_threshold, short_weight)
|
||||
|
||||
# 添加兼容性字段
|
||||
if 'Ensemble_Probability' in results_df.columns:
|
||||
results_df['Voting_Probability'] = results_df['Ensemble_Probability']
|
||||
results_df['Weighted_Probability'] = results_df['Ensemble_Probability']
|
||||
if 'Ensemble_Weights' in results_df.columns:
|
||||
results_df['Weight_Info'] = results_df['Ensemble_Weights']
|
||||
if 'Short_Sequence' in results_df.columns:
|
||||
results_df['33bp'] = results_df['Short_Sequence']
|
||||
if 'Long_Sequence' in results_df.columns:
|
||||
results_df['399bp'] = results_df['Long_Sequence']
|
||||
|
||||
if plot:
|
||||
# 如果需要绘图,调用绘图方法
|
||||
_, fig = self.plot_sequence_prediction(sequence, window_size, 0.65, 0.8, short_weight)
|
||||
return results_df, fig
|
||||
|
||||
return results_df
|
||||
|
||||
def predict_region(self, seq, short_threshold=0.1, short_weight=0.4):
|
||||
"""
|
||||
⚠️ 已废弃:请使用 predict_regions() 方法
|
||||
|
||||
向后兼容的方法,内部调用新的 predict_regions()
|
||||
"""
|
||||
import warnings
|
||||
warnings.warn("predict_region() 已废弃,请使用 predict_regions() 方法", DeprecationWarning, stacklevel=2)
|
||||
|
||||
# 调用新方法并添加兼容性字段
|
||||
results_df = self.predict_regions(seq, short_threshold, short_weight)
|
||||
|
||||
# 添加兼容性字段
|
||||
if 'Ensemble_Probability' in results_df.columns:
|
||||
results_df['Voting_Probability'] = results_df['Ensemble_Probability']
|
||||
results_df['Weighted_Probability'] = results_df['Ensemble_Probability']
|
||||
if 'Ensemble_Weights' in results_df.columns:
|
||||
results_df['Weight_Info'] = results_df['Ensemble_Weights']
|
||||
if 'Short_Sequence' in results_df.columns:
|
||||
results_df['33bp'] = results_df['Short_Sequence']
|
||||
if 'Long_Sequence' in results_df.columns:
|
||||
results_df['399bp'] = results_df['Long_Sequence']
|
||||
|
||||
import os
|
||||
import pickle
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from tensorflow.keras.models import load_model
|
||||
from .features.sequence import SequenceFeatureExtractor
|
||||
from .features.cnn_input import CNNInputProcessor
|
||||
from .utils import extract_window_sequences
|
||||
import matplotlib.pyplot as plt
|
||||
import joblib
|
||||
|
||||
|
||||
class PRFPredictor:
|
||||
|
||||
def __init__(self, model_dir=None):
|
||||
"""
|
||||
初始化PRF预测器
|
||||
|
||||
Args:
|
||||
model_dir: 模型目录路径(可选)
|
||||
"""
|
||||
if model_dir is None:
|
||||
from pkg_resources import resource_filename
|
||||
model_dir = resource_filename('FScanpy', 'pretrained')
|
||||
|
||||
try:
|
||||
# 加载模型 - 使用新的命名约定
|
||||
self.short_model = self._load_pickle(os.path.join(model_dir, 'short.pkl')) # HistGB模型
|
||||
self.long_model = self._load_pickle(os.path.join(model_dir, 'long.pkl')) # BiLSTM-CNN模型
|
||||
|
||||
# 初始化特征提取器和CNN处理器,使用与训练时相同的序列长度
|
||||
self.short_seq_length = 33 # HistGB使用的序列长度
|
||||
self.long_seq_length = 399 # BiLSTM-CNN使用的序列长度
|
||||
|
||||
# 初始化特征提取器和CNN输入处理器
|
||||
self.feature_extractor = SequenceFeatureExtractor(seq_length=self.short_seq_length)
|
||||
self.cnn_processor = CNNInputProcessor(max_length=self.long_seq_length)
|
||||
|
||||
# 检测模型类型以优化预测性能
|
||||
self._detect_model_types()
|
||||
|
||||
except FileNotFoundError as e:
|
||||
raise FileNotFoundError(f"无法找到模型文件: {str(e)}。请确保模型文件 'short.pkl' 和 'long.pkl' 存在于 {model_dir}")
|
||||
except Exception as e:
|
||||
raise Exception(f"加载模型出错: {str(e)}")
|
||||
|
||||
def _load_pickle(self, path):
|
||||
"""安全加载pickle文件"""
|
||||
try:
|
||||
return joblib.load(path)
|
||||
except Exception as e:
|
||||
raise FileNotFoundError(f"无法加载模型文件 {path}: {str(e)}")
|
||||
|
||||
def _detect_model_types(self):
|
||||
"""检测模型类型以优化预测性能"""
|
||||
self.short_is_sklearn = hasattr(self.short_model, 'predict_proba')
|
||||
self.long_is_sklearn = hasattr(self.long_model, 'predict_proba')
|
||||
|
||||
def _predict_model(self, model, features, is_sklearn, seq_length):
|
||||
"""统一的模型预测方法"""
|
||||
try:
|
||||
if is_sklearn:
|
||||
# sklearn模型使用特征向量
|
||||
if isinstance(features, np.ndarray) and features.ndim > 1:
|
||||
features = features.flatten()
|
||||
features_2d = np.array([features])
|
||||
pred = model.predict_proba(features_2d)
|
||||
return pred[0][1]
|
||||
else:
|
||||
# 深度学习模型
|
||||
if seq_length == self.long_seq_length:
|
||||
# 对于长序列,使用CNN处理器
|
||||
model_input = self.cnn_processor.prepare_sequence(features)
|
||||
else:
|
||||
# 对于短序列,转换为数值编码
|
||||
base_to_num = {'A': 1, 'T': 2, 'G': 3, 'C': 4, 'N': 0}
|
||||
seq_numeric = [base_to_num.get(base, 0) for base in features.upper()]
|
||||
model_input = np.array(seq_numeric).reshape(1, len(seq_numeric), 1)
|
||||
|
||||
# 统一的预测调用
|
||||
try:
|
||||
pred = model.predict(model_input, verbose=0)
|
||||
except TypeError:
|
||||
pred = model.predict(model_input)
|
||||
|
||||
# 处理预测结果
|
||||
if isinstance(pred, list):
|
||||
pred = pred[0]
|
||||
|
||||
if hasattr(pred, 'shape') and len(pred.shape) > 1 and pred.shape[1] > 1:
|
||||
return pred[0][1]
|
||||
else:
|
||||
return pred[0][0] if hasattr(pred[0], '__getitem__') else pred[0]
|
||||
|
||||
except Exception as e:
|
||||
raise Exception(f"模型预测失败: {str(e)}")
|
||||
|
||||
def predict_single_position(self, fs_period, full_seq, short_threshold=0.1, ensemble_weight=0.4):
|
||||
'''
|
||||
预测单个位置的PRF状态
|
||||
|
||||
Args:
|
||||
fs_period: 33bp序列 (short模型使用)
|
||||
full_seq: 完整序列 (long模型使用)
|
||||
short_threshold: short模型的概率阈值 (默认为0.1)
|
||||
ensemble_weight: short模型在集成中的权重 (默认为0.4,long权重为0.6)
|
||||
Returns:
|
||||
dict: 包含预测概率的字典
|
||||
'''
|
||||
try:
|
||||
# 验证权重参数
|
||||
if not (0.0 <= ensemble_weight <= 1.0):
|
||||
raise ValueError("ensemble_weight 必须在 0.0 到 1.0 之间")
|
||||
|
||||
long_weight = 1.0 - ensemble_weight
|
||||
|
||||
# 处理序列长度
|
||||
if len(fs_period) > self.short_seq_length:
|
||||
fs_period = self.feature_extractor.trim_sequence(fs_period, self.short_seq_length)
|
||||
|
||||
# Short模型预测 (HistGB)
|
||||
try:
|
||||
if self.short_is_sklearn:
|
||||
short_features = self.feature_extractor.extract_features(fs_period)
|
||||
short_prob = self._predict_model(self.short_model, short_features, True, self.short_seq_length)
|
||||
else:
|
||||
short_prob = self._predict_model(self.short_model, fs_period, False, self.short_seq_length)
|
||||
except Exception as e:
|
||||
print(f"Short模型预测时出错: {str(e)}")
|
||||
short_prob = 0.0
|
||||
|
||||
# 如果short概率低于阈值,则跳过long模型
|
||||
if short_prob < short_threshold:
|
||||
return {
|
||||
'Short_Probability': short_prob,
|
||||
'Long_Probability': 0.0,
|
||||
'Ensemble_Probability': 0.0,
|
||||
'Ensemble_Weights': f'Short:{ensemble_weight:.1f}, Long:{long_weight:.1f}'
|
||||
}
|
||||
|
||||
# Long模型预测 (BiLSTM-CNN)
|
||||
try:
|
||||
if self.long_is_sklearn:
|
||||
long_features = self.feature_extractor.extract_features(full_seq)
|
||||
long_prob = self._predict_model(self.long_model, long_features, True, self.long_seq_length)
|
||||
else:
|
||||
long_prob = self._predict_model(self.long_model, full_seq, False, self.long_seq_length)
|
||||
except Exception as e:
|
||||
print(f"Long模型预测时出错: {str(e)}")
|
||||
long_prob = 0.0
|
||||
|
||||
# 计算集成概率
|
||||
try:
|
||||
ensemble_prob = ensemble_weight * short_prob + long_weight * long_prob
|
||||
except Exception as e:
|
||||
print(f"计算集成概率时出错: {str(e)}")
|
||||
ensemble_prob = (short_prob + long_prob) / 2
|
||||
|
||||
return {
|
||||
'Short_Probability': short_prob,
|
||||
'Long_Probability': long_prob,
|
||||
'Ensemble_Probability': ensemble_prob,
|
||||
'Ensemble_Weights': f'Short:{ensemble_weight:.1f}, Long:{long_weight:.1f}'
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
raise Exception(f"预测过程出错: {str(e)}")
|
||||
|
||||
def predict_sequence(self, sequence, window_size=3, short_threshold=0.1, ensemble_weight=0.4):
|
||||
"""
|
||||
预测完整序列中的PRF位点(滑动窗口方法)
|
||||
|
||||
Args:
|
||||
sequence: 输入DNA序列
|
||||
window_size: 滑动窗口大小 (默认为3)
|
||||
short_threshold: short模型概率阈值 (默认为0.1)
|
||||
ensemble_weight: short模型在集成中的权重 (默认为0.4)
|
||||
|
||||
Returns:
|
||||
pd.DataFrame: 包含预测结果的DataFrame
|
||||
"""
|
||||
if window_size < 1:
|
||||
raise ValueError("窗口大小必须大于等于1")
|
||||
if short_threshold < 0:
|
||||
raise ValueError("short模型阈值必须大于等于0")
|
||||
if not (0.0 <= ensemble_weight <= 1.0):
|
||||
raise ValueError("ensemble_weight 必须在 0.0 到 1.0 之间")
|
||||
|
||||
results = []
|
||||
long_weight = 1.0 - ensemble_weight
|
||||
|
||||
try:
|
||||
# 确保序列为字符串并转换为大写
|
||||
sequence = str(sequence).upper()
|
||||
|
||||
# 滑动窗口预测
|
||||
for pos in range(0, len(sequence) - 2, window_size):
|
||||
# 提取窗口序列
|
||||
fs_period, full_seq = extract_window_sequences(sequence, pos)
|
||||
|
||||
if fs_period is None or full_seq is None:
|
||||
continue
|
||||
|
||||
# 预测并记录结果
|
||||
pred = self.predict_single_position(fs_period, full_seq, short_threshold, ensemble_weight)
|
||||
pred.update({
|
||||
'Position': pos,
|
||||
'Codon': sequence[pos:pos+3],
|
||||
'Short_Sequence': fs_period, # 更清晰的命名
|
||||
'Long_Sequence': full_seq # 更清晰的命名
|
||||
})
|
||||
results.append(pred)
|
||||
|
||||
# 创建结果DataFrame
|
||||
results_df = pd.DataFrame(results)
|
||||
|
||||
return results_df
|
||||
|
||||
except Exception as e:
|
||||
raise Exception(f"序列预测过程出错: {str(e)}")
|
||||
|
||||
def plot_sequence_prediction(self, sequence, window_size=3, short_threshold=0.65,
|
||||
long_threshold=0.8, ensemble_weight=0.4, title=None, save_path=None,
|
||||
figsize=(12, 8), dpi=300):
|
||||
"""
|
||||
Plot sequence PRF prediction results
|
||||
|
||||
Args:
|
||||
sequence: Input DNA sequence
|
||||
window_size: Sliding window size (default: 3)
|
||||
short_threshold: Short model (HistGB) filtering threshold (default: 0.65)
|
||||
long_threshold: Long model (BiLSTM-CNN) filtering threshold (default: 0.8)
|
||||
ensemble_weight: Weight of short model in ensemble (default: 0.4)
|
||||
title: Plot title (optional)
|
||||
save_path: Save path (optional, saves plot if provided)
|
||||
figsize: Figure size (default: (12, 8))
|
||||
dpi: Figure resolution (default: 300)
|
||||
|
||||
Returns:
|
||||
tuple: (pd.DataFrame, matplotlib.figure.Figure) prediction results and figure object
|
||||
"""
|
||||
try:
|
||||
# Validate weight parameter
|
||||
if not (0.0 <= ensemble_weight <= 1.0):
|
||||
raise ValueError("ensemble_weight must be between 0.0 and 1.0")
|
||||
|
||||
long_weight = 1.0 - ensemble_weight
|
||||
|
||||
# Get prediction results
|
||||
results_df = self.predict_sequence(sequence, window_size=window_size,
|
||||
short_threshold=0.1, ensemble_weight=ensemble_weight)
|
||||
|
||||
if results_df.empty:
|
||||
raise ValueError("Prediction results are empty, please check input sequence")
|
||||
|
||||
# Get sequence length
|
||||
seq_length = len(sequence)
|
||||
|
||||
# Calculate display width
|
||||
desired_visual_width = max(3, seq_length // 100) # FS site width ~1% of sequence length
|
||||
prob_width = max(1, desired_visual_width // 3) # Prediction probability width is 1/3 of FS site width
|
||||
|
||||
# Create figure with three subplots, set height ratios
|
||||
fig = plt.figure(figsize=figsize)
|
||||
|
||||
# Set title
|
||||
if title:
|
||||
fig.suptitle(title, y=0.95, fontsize=10)
|
||||
else:
|
||||
fig.suptitle(f'PRF Prediction Results (Weights {ensemble_weight:.1f}:{long_weight:.1f})', y=0.95, fontsize=10)
|
||||
|
||||
# Adjust subplot ratios, make top two heatmaps smaller
|
||||
gs = fig.add_gridspec(3, 1, height_ratios=[0.1, 0.1, 1], hspace=0.2)
|
||||
|
||||
# FS site heatmap - using fixed width, no blur effect
|
||||
ax0 = fig.add_subplot(gs[0])
|
||||
fs_data = np.zeros((1, seq_length))
|
||||
# Note: No actual FS site information in sliding window prediction, so keep empty or show predicted sites
|
||||
# Show high-confidence predictions as potential FS sites
|
||||
for _, row in results_df.iterrows():
|
||||
pos = int(row['Position'])
|
||||
if (row['Short_Probability'] >= short_threshold and
|
||||
row['Long_Probability'] >= long_threshold and
|
||||
row['Ensemble_Probability'] >= 0.8): # High confidence threshold
|
||||
half_width = desired_visual_width // 2
|
||||
start_pos = max(0, pos - half_width)
|
||||
end_pos = min(seq_length, pos + half_width + 1)
|
||||
fs_data[0, start_pos:end_pos] = 1 # Use fixed value, no gradient
|
||||
|
||||
ax0.imshow(fs_data, cmap='Reds', aspect='auto', interpolation='nearest')
|
||||
ax0.set_xticks([])
|
||||
ax0.set_yticks([])
|
||||
ax0.set_title('FS site', pad=2, fontsize=8)
|
||||
|
||||
# Prediction probability heatmap - using fixed width to display probabilities
|
||||
ax1 = fig.add_subplot(gs[1])
|
||||
prob_data = np.zeros((1, seq_length))
|
||||
|
||||
# Apply dual threshold filtering
|
||||
for _, row in results_df.iterrows():
|
||||
pos = int(row['Position'])
|
||||
if (row['Short_Probability'] >= short_threshold and
|
||||
row['Long_Probability'] >= long_threshold):
|
||||
# Set fixed width for each probability value
|
||||
start = max(0, pos - prob_width//2)
|
||||
end = min(seq_length, pos + prob_width//2 + 1)
|
||||
prob_data[0, start:end] = row['Ensemble_Probability']
|
||||
|
||||
im = ax1.imshow(prob_data, cmap='Reds', aspect='auto', vmin=0, vmax=1, interpolation='nearest')
|
||||
ax1.set_xticks([])
|
||||
ax1.set_yticks([])
|
||||
ax1.set_title('Prediction', pad=2, fontsize=8)
|
||||
|
||||
# Main plot (bar chart)
|
||||
ax2 = fig.add_subplot(gs[2])
|
||||
|
||||
# Apply filtering thresholds
|
||||
filtered_probs = results_df['Ensemble_Probability'].copy()
|
||||
mask = ((results_df['Short_Probability'] < short_threshold) |
|
||||
(results_df['Long_Probability'] < long_threshold))
|
||||
filtered_probs[mask] = 0
|
||||
|
||||
# Draw bar chart - use black color and alpha=0.6 to match prediction_sample style
|
||||
ax2.bar(results_df['Position'], filtered_probs,
|
||||
alpha=0.6, color='black', width=1.0)
|
||||
|
||||
# Set x-axis ticks
|
||||
step = max(seq_length // 10, 50)
|
||||
ax2.set_xticks(np.arange(0, seq_length, step))
|
||||
ax2.tick_params(axis='x', rotation=45)
|
||||
|
||||
# Set labels
|
||||
ax2.set_xlabel('Position')
|
||||
ax2.set_ylabel('Probability')
|
||||
|
||||
# Set y-axis range
|
||||
ax2.set_ylim(0, 1)
|
||||
|
||||
# Add grid
|
||||
ax2.grid(True, alpha=0.3)
|
||||
|
||||
# Ensure all subplots have consistent x-axis range
|
||||
for ax in [ax0, ax1, ax2]:
|
||||
ax.set_xlim(-1, seq_length)
|
||||
|
||||
# Adjust layout
|
||||
plt.tight_layout()
|
||||
|
||||
# Save plot if save path is provided
|
||||
if save_path:
|
||||
plt.savefig(save_path, dpi=dpi, bbox_inches='tight')
|
||||
# Also save PDF version
|
||||
if save_path.endswith('.png'):
|
||||
pdf_path = save_path.replace('.png', '.pdf')
|
||||
plt.savefig(pdf_path, bbox_inches='tight')
|
||||
print(f"Plot saved to: {save_path}")
|
||||
|
||||
return results_df, fig
|
||||
|
||||
except Exception as e:
|
||||
raise Exception(f"Error plotting sequence prediction: {str(e)}")
|
||||
|
||||
def predict_regions(self, sequences, short_threshold=0.1, ensemble_weight=0.4):
|
||||
'''
|
||||
Predict region sequences (batch prediction of known 399bp sequences)
|
||||
|
||||
Args:
|
||||
sequences: 399bp sequences or DataFrame/Series/list containing 399bp sequences
|
||||
short_threshold: Short model probability threshold (default: 0.1)
|
||||
ensemble_weight: Weight of short model in ensemble (default: 0.4)
|
||||
|
||||
Returns:
|
||||
DataFrame: DataFrame containing prediction probabilities for all sequences
|
||||
'''
|
||||
try:
|
||||
# Validate weight parameter
|
||||
if not (0.0 <= ensemble_weight <= 1.0):
|
||||
raise ValueError("ensemble_weight must be between 0.0 and 1.0")
|
||||
|
||||
# Unify input format
|
||||
if isinstance(sequences, (pd.DataFrame, pd.Series)):
|
||||
sequences = sequences.tolist()
|
||||
elif isinstance(sequences, str):
|
||||
sequences = [sequences]
|
||||
|
||||
results = []
|
||||
for i, seq399 in enumerate(sequences):
|
||||
try:
|
||||
# Extract central 33bp from 399bp sequence (for short model use)
|
||||
seq33 = self._extract_center_sequence(seq399, target_length=self.short_seq_length)
|
||||
|
||||
# Use unified prediction method
|
||||
pred_result = self.predict_single_position(seq33, seq399, short_threshold, ensemble_weight)
|
||||
pred_result.update({
|
||||
'Short_Sequence': seq33,
|
||||
'Long_Sequence': seq399
|
||||
})
|
||||
|
||||
results.append(pred_result)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing sequence {i+1}: {str(e)}")
|
||||
long_weight = 1.0 - ensemble_weight
|
||||
results.append({
|
||||
'Short_Probability': 0.0,
|
||||
'Long_Probability': 0.0,
|
||||
'Ensemble_Probability': 0.0,
|
||||
'Ensemble_Weights': f'Short:{ensemble_weight:.1f}, Long:{long_weight:.1f}',
|
||||
'Short_Sequence': self._extract_center_sequence(seq399, target_length=self.short_seq_length) if len(seq399) >= self.short_seq_length else seq399,
|
||||
'Long_Sequence': seq399
|
||||
})
|
||||
|
||||
return pd.DataFrame(results)
|
||||
|
||||
except Exception as e:
|
||||
raise Exception(f"Error in region prediction process: {str(e)}")
|
||||
|
||||
def _extract_center_sequence(self, sequence, target_length=33):
|
||||
"""Extract subsequence of specified length from center position of sequence"""
|
||||
# Ensure sequence is string
|
||||
sequence = str(sequence).upper()
|
||||
|
||||
# If sequence length is less than target length, return original sequence
|
||||
if len(sequence) <= target_length:
|
||||
return sequence
|
||||
|
||||
# Calculate center position
|
||||
center = len(sequence) // 2
|
||||
half_target = target_length // 2
|
||||
|
||||
# Extract center sequence
|
||||
start = center - half_target
|
||||
end = start + target_length
|
||||
|
||||
# Boundary check
|
||||
if start < 0:
|
||||
start = 0
|
||||
end = target_length
|
||||
elif end > len(sequence):
|
||||
end = len(sequence)
|
||||
start = end - target_length
|
||||
|
||||
return sequence[start:end]
|
||||
|
||||
# 兼容性方法(向后兼容,但标记为废弃)
|
||||
def predict_full(self, sequence, window_size=3, short_threshold=0.1, short_weight=0.4, plot=False):
|
||||
"""
|
||||
⚠️ 已废弃:请使用 predict_sequence() 方法
|
||||
|
||||
向后兼容的方法,内部调用新的 predict_sequence()
|
||||
"""
|
||||
import warnings
|
||||
warnings.warn("predict_full() 已废弃,请使用 predict_sequence() 方法", DeprecationWarning, stacklevel=2)
|
||||
|
||||
# 调用新方法并添加兼容性字段
|
||||
results_df = self.predict_sequence(sequence, window_size, short_threshold, short_weight)
|
||||
|
||||
# 添加兼容性字段
|
||||
if 'Ensemble_Probability' in results_df.columns:
|
||||
results_df['Voting_Probability'] = results_df['Ensemble_Probability']
|
||||
results_df['Weighted_Probability'] = results_df['Ensemble_Probability']
|
||||
if 'Ensemble_Weights' in results_df.columns:
|
||||
results_df['Weight_Info'] = results_df['Ensemble_Weights']
|
||||
if 'Short_Sequence' in results_df.columns:
|
||||
results_df['33bp'] = results_df['Short_Sequence']
|
||||
if 'Long_Sequence' in results_df.columns:
|
||||
results_df['399bp'] = results_df['Long_Sequence']
|
||||
|
||||
if plot:
|
||||
# 如果需要绘图,调用绘图方法
|
||||
_, fig = self.plot_sequence_prediction(sequence, window_size, 0.65, 0.8, short_weight)
|
||||
return results_df, fig
|
||||
|
||||
return results_df
|
||||
|
||||
def predict_region(self, seq, short_threshold=0.1, short_weight=0.4):
|
||||
"""
|
||||
⚠️ 已废弃:请使用 predict_regions() 方法
|
||||
|
||||
向后兼容的方法,内部调用新的 predict_regions()
|
||||
"""
|
||||
import warnings
|
||||
warnings.warn("predict_region() 已废弃,请使用 predict_regions() 方法", DeprecationWarning, stacklevel=2)
|
||||
|
||||
# 调用新方法并添加兼容性字段
|
||||
results_df = self.predict_regions(seq, short_threshold, short_weight)
|
||||
|
||||
# 添加兼容性字段
|
||||
if 'Ensemble_Probability' in results_df.columns:
|
||||
results_df['Voting_Probability'] = results_df['Ensemble_Probability']
|
||||
results_df['Weighted_Probability'] = results_df['Ensemble_Probability']
|
||||
if 'Ensemble_Weights' in results_df.columns:
|
||||
results_df['Weight_Info'] = results_df['Ensemble_Weights']
|
||||
if 'Short_Sequence' in results_df.columns:
|
||||
results_df['33bp'] = results_df['Short_Sequence']
|
||||
if 'Long_Sequence' in results_df.columns:
|
||||
results_df['399bp'] = results_df['Long_Sequence']
|
||||
|
||||
return results_df
|
||||
|
|
@ -6,65 +6,58 @@
|
|||
"source": [
|
||||
"# FScanpy \n",
|
||||
"\n",
|
||||
"这个 Notebook 展示了如何使用 FScanpy 的真实测试数据进行完整的 PRF 位点预测分析,包括:\n",
|
||||
"This notebook demonstrates how to use FScanpy with real test data for complete PRF site prediction analysis, including:\n",
|
||||
"\n",
|
||||
"## 🎯 完整工作流程\n",
|
||||
"1. **加载测试数据** - 使用内置的真实测试数据\n",
|
||||
"2. **FScanR 分析** - 从 BLASTX 结果识别潜在 PRF 位点\n",
|
||||
"3. **序列提取** - 提取 PRF 位点周围的序列\n",
|
||||
"4. **FScanpy 预测** - 使用机器学习模型预测概率\n",
|
||||
"5. **结果可视化** - 使用内置绘图函数生成预测结果图表\n",
|
||||
"6. **序列级预测演示** - 完整序列的滑动窗口分析\n",
|
||||
"## 🎯 Complete Workflow\n",
|
||||
"1. **Load Test Data** - Use built-in real test data\n",
|
||||
"2. **FScanR Analysis** - Identify potential PRF sites from BLASTX results\n",
|
||||
"3. **Sequence Extraction** - Extract sequences around PRF sites\n",
|
||||
"4. **FScanpy Prediction** - Use machine learning models to predict probabilities\n",
|
||||
"5. **Results Visualization** - Generate prediction result plots using built-in plotting functions\n",
|
||||
"6. **Sequence-level Prediction Demo** - Sliding window analysis of complete sequences\n",
|
||||
"\n",
|
||||
"## 📊 数据说明\n",
|
||||
"- **blastx_example.xlsx**: 真实BLASTX比对结果\n",
|
||||
"- **mrna_example.fasta**: 真实mRNA序列数据\n",
|
||||
"- **region_example.csv**: 单独对某个位点进行预测的样本"
|
||||
"## 📊 Data Description\n",
|
||||
"- **blastx_example.xlsx**: Real BLASTX alignment results\n",
|
||||
"- **mrna_example.fasta**: Real mRNA sequence data\n",
|
||||
"- **region_example.csv**: Sample for individual site prediction"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 📦 环境准备和数据加载"
|
||||
"## 📦 Environment Setup and Data Loading"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 22,
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"✅ 环境准备完成!\n",
|
||||
"📋 可用的测试数据:\n"
|
||||
"ename": "ImportError",
|
||||
"evalue": "cannot import name 'PRFPredictor' from 'FScanpy' (unknown location)",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[0;31mImportError\u001b[0m Traceback (most recent call last)",
|
||||
"Cell \u001b[0;32mIn[3], line 6\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mmatplotlib\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mpyplot\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mplt\u001b[39;00m\n\u001b[1;32m 5\u001b[0m \u001b[38;5;66;03m# Import FScanpy related modules\u001b[39;00m\n\u001b[0;32m----> 6\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mFScanpy\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m PRFPredictor, predict_prf, plot_prf_prediction\n\u001b[1;32m 7\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mFScanpy\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mdata\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m get_test_data_path, list_test_data\n\u001b[1;32m 8\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mFScanpy\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mutils\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m fscanr, extract_prf_regions\n",
|
||||
"\u001b[0;31mImportError\u001b[0m: cannot import name 'PRFPredictor' from 'FScanpy' (unknown location)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"['blastx_example.xlsx', 'mrna_example.fasta', 'region_example.csv']"
|
||||
]
|
||||
},
|
||||
"execution_count": 22,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# 导入必要的库\n",
|
||||
"# Import necessary libraries\n",
|
||||
"import pandas as pd\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"\n",
|
||||
"# 导入FScanpy相关模块\n",
|
||||
"# Import FScanpy related modules\n",
|
||||
"from FScanpy import PRFPredictor, predict_prf, plot_prf_prediction\n",
|
||||
"from FScanpy.data import get_test_data_path, list_test_data\n",
|
||||
"from FScanpy.utils import fscanr, extract_prf_regions\n",
|
||||
"\n",
|
||||
"print(\"✅ 环境准备完成!\")\n",
|
||||
"print(\"📋 可用的测试数据:\")\n",
|
||||
"print(\"✅ Environment setup complete!\")\n",
|
||||
"print(\"📋 Available test data:\")\n",
|
||||
"list_test_data()"
|
||||
]
|
||||
},
|
||||
|
|
@ -72,9 +65,9 @@
|
|||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 1. 加载和探索测试数据\n",
|
||||
"## 1. Load and Explore Test Data\n",
|
||||
"\n",
|
||||
"首先加载 FScanpy 提供的真实测试数据,了解数据结构。"
|
||||
"First, load the real test data provided by FScanpy to understand the data structure."
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
@ -107,25 +100,25 @@
|
|||
}
|
||||
],
|
||||
"source": [
|
||||
"# 获取测试数据路径\n",
|
||||
"# Get test data paths\n",
|
||||
"blastx_file = get_test_data_path('blastx_example.xlsx')\n",
|
||||
"mrna_file = get_test_data_path('mrna_example.fasta')\n",
|
||||
"region_file = get_test_data_path('region_example.csv')\n",
|
||||
"\n",
|
||||
"print(f\"📁 数据文件路径:\")\n",
|
||||
"print(f\" BLASTX数据: {blastx_file}\")\n",
|
||||
"print(f\" mRNA序列: {mrna_file}\")\n",
|
||||
"print(f\" 验证区域: {region_file}\")\n",
|
||||
"print(f\"📁 Data file paths:\")\n",
|
||||
"print(f\" BLASTX data: {blastx_file}\")\n",
|
||||
"print(f\" mRNA sequences: {mrna_file}\")\n",
|
||||
"print(f\" Validation regions: {region_file}\")\n",
|
||||
"\n",
|
||||
"# 加载BLASTX数据\n",
|
||||
"# Load BLASTX data\n",
|
||||
"blastx_data = pd.read_excel(blastx_file)\n",
|
||||
"print(f\"\\n🧬 BLASTX数据概览:\")\n",
|
||||
"print(f\" 数据形状: {blastx_data.shape}\")\n",
|
||||
"print(f\" 列名: {list(blastx_data.columns)}\")\n",
|
||||
"print(f\" 唯一序列数: {blastx_data['DNA_seqid'].nunique()}\")\n",
|
||||
"print(f\"\\n🧬 BLASTX data overview:\")\n",
|
||||
"print(f\" Data shape: {blastx_data.shape}\")\n",
|
||||
"print(f\" Column names: {list(blastx_data.columns)}\")\n",
|
||||
"print(f\" Unique sequences: {blastx_data['DNA_seqid'].nunique()}\")\n",
|
||||
"\n",
|
||||
"# 显示前几行\n",
|
||||
"print(\"\\n📊 BLASTX数据示例:\")\n",
|
||||
"# Display first few rows\n",
|
||||
"print(\"\\n📊 BLASTX data examples:\")\n",
|
||||
"display_cols = ['DNA_seqid', 'Pep_seqid', 'pident', 'length', 'evalue', 'qframe']\n",
|
||||
"print(blastx_data[display_cols].head())"
|
||||
]
|
||||
|
|
@ -163,21 +156,21 @@
|
|||
}
|
||||
],
|
||||
"source": [
|
||||
"# 加载验证区域数据\n",
|
||||
"# Load validation region data\n",
|
||||
"region_data = pd.read_csv(region_file)\n",
|
||||
"print(f\"🎯 验证区域数据概览:\")\n",
|
||||
"print(f\" 数据形状: {region_data.shape}\")\n",
|
||||
"print(f\" 列名: {list(region_data.columns)}\")\n",
|
||||
"print(f\" 数据来源: {region_data['source'].value_counts().to_dict()}\")\n",
|
||||
"print(f\"🎯 Validation region data overview:\")\n",
|
||||
"print(f\" Data shape: {region_data.shape}\")\n",
|
||||
"print(f\" Column names: {list(region_data.columns)}\")\n",
|
||||
"print(f\" Data sources: {region_data['source'].value_counts().to_dict()}\")\n",
|
||||
"\n",
|
||||
"print(\"\\n📋 验证区域数据示例:\")\n",
|
||||
"print(\"\\n📋 Validation region data examples:\")\n",
|
||||
"display_cols = ['fs_position', 'DNA_seqid', 'label', 'source', 'FS_type']\n",
|
||||
"print(region_data[display_cols].head())\n",
|
||||
"\n",
|
||||
"# 统计分析\n",
|
||||
"print(f\"\\n📈 标签分布:\")\n",
|
||||
"# Statistical analysis\n",
|
||||
"print(f\"\\n📈 Label distribution:\")\n",
|
||||
"print(region_data['label'].value_counts())\n",
|
||||
"print(f\"\\n🔬 FS类型分布:\")\n",
|
||||
"print(f\"\\n🔬 FS type distribution:\")\n",
|
||||
"print(region_data['FS_type'].value_counts())"
|
||||
]
|
||||
},
|
||||
|
|
@ -185,9 +178,9 @@
|
|||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 2. FScanR 分析 - 从 BLASTX 识别潜在 PRF 位点\n",
|
||||
"## 2. FScanR Analysis - Identify Potential PRF Sites from BLASTX\n",
|
||||
"\n",
|
||||
"使用 FScanR 算法分析 BLASTX 结果,识别潜在的程序性核糖体移码位点。"
|
||||
"Use the FScanR algorithm to analyze BLASTX results and identify potential programmed ribosomal frameshift sites."
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
@ -229,9 +222,9 @@
|
|||
}
|
||||
],
|
||||
"source": [
|
||||
"# 运行FScanR分析\n",
|
||||
"print(\"🔍 运行FScanR分析...\")\n",
|
||||
"print(\"参数设置: mismatch_cutoff=10, evalue_cutoff=1e-5, frameDist_cutoff=10\")\n",
|
||||
"# Run FScanR analysis\n",
|
||||
"print(\"🔍 Running FScanR analysis...\")\n",
|
||||
"print(\"Parameter settings: mismatch_cutoff=10, evalue_cutoff=1e-5, frameDist_cutoff=100\")\n",
|
||||
"\n",
|
||||
"fscanr_results = fscanr(\n",
|
||||
" blastx_data,\n",
|
||||
|
|
@ -240,29 +233,29 @@
|
|||
" frameDist_cutoff=100\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print(f\"\\n✅ FScanR分析完成!\")\n",
|
||||
"print(f\"检测到的潜在PRF位点数量: {len(fscanr_results)}\")\n",
|
||||
"print(f\"\\n✅ FScanR analysis complete!\")\n",
|
||||
"print(f\"Number of potential PRF sites detected: {len(fscanr_results)}\")\n",
|
||||
"\n",
|
||||
"if len(fscanr_results) > 0:\n",
|
||||
" print(f\"\\n📊 FScanR结果概览:\")\n",
|
||||
" print(f\" 列名: {list(fscanr_results.columns)}\")\n",
|
||||
" print(f\" 涉及的序列数: {fscanr_results['DNA_seqid'].nunique()}\")\n",
|
||||
" print(f\" 链方向分布: {fscanr_results['Strand'].value_counts().to_dict()}\")\n",
|
||||
" print(f\" FS类型分布: {fscanr_results['FS_type'].value_counts().to_dict()}\")\n",
|
||||
" print(f\"\\n📊 FScanR results overview:\")\n",
|
||||
" print(f\" Column names: {list(fscanr_results.columns)}\")\n",
|
||||
" print(f\" Number of sequences involved: {fscanr_results['DNA_seqid'].nunique()}\")\n",
|
||||
" print(f\" Strand orientation distribution: {fscanr_results['Strand'].value_counts().to_dict()}\")\n",
|
||||
" print(f\" FS type distribution: {fscanr_results['FS_type'].value_counts().to_dict()}\")\n",
|
||||
" \n",
|
||||
" print(\"\\n🎯 FScanR结果示例:\")\n",
|
||||
" print(\"\\n🎯 FScanR results examples:\")\n",
|
||||
" print(fscanr_results.head())\n",
|
||||
"else:\n",
|
||||
" print(\"⚠️ 未检测到PRF位点,可能需要调整参数\")"
|
||||
" print(\"⚠️ No PRF sites detected, may need to adjust parameters\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 3. 序列提取 - 获取 PRF 位点周围序列\n",
|
||||
"## 3. Sequence Extraction - Extract Sequences Around PRF Sites\n",
|
||||
"\n",
|
||||
"从 mRNA 序列中提取 FScanR 识别的 PRF 位点周围的序列片段。"
|
||||
"Extract sequence fragments around PRF sites identified by FScanR from mRNA sequences."
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
@ -306,36 +299,36 @@
|
|||
}
|
||||
],
|
||||
"source": [
|
||||
"# 提取PRF位点周围的序列\n",
|
||||
"# Extract sequences around PRF sites\n",
|
||||
"if len(fscanr_results) > 0:\n",
|
||||
" print(\"📝 从mRNA序列中提取PRF位点周围序列...\")\n",
|
||||
" print(\"📝 Extracting sequences around PRF sites from mRNA sequences...\")\n",
|
||||
" \n",
|
||||
" prf_sequences = extract_prf_regions(\n",
|
||||
" mrna_file=mrna_file,\n",
|
||||
" prf_data=fscanr_results\n",
|
||||
" )\n",
|
||||
" \n",
|
||||
" print(f\"\\n✅ 序列提取完成!\")\n",
|
||||
" print(f\"成功提取的序列数量: {len(prf_sequences)}\")\n",
|
||||
" print(f\"\\n✅ Sequence extraction complete!\")\n",
|
||||
" print(f\"Number of successfully extracted sequences: {len(prf_sequences)}\")\n",
|
||||
" \n",
|
||||
" if len(prf_sequences) > 0:\n",
|
||||
" print(f\"\\n📏 序列长度验证:\")\n",
|
||||
" print(f\"\\n📏 Sequence length validation:\")\n",
|
||||
" seq_lengths = prf_sequences['399bp'].str.len()\n",
|
||||
" print(f\" 399bp序列长度分布: {seq_lengths.value_counts().to_dict()}\")\n",
|
||||
" print(f\" 平均长度: {seq_lengths.mean():.1f}\")\n",
|
||||
" print(f\" 399bp sequence length distribution: {seq_lengths.value_counts().to_dict()}\")\n",
|
||||
" print(f\" Average length: {seq_lengths.mean():.1f}\")\n",
|
||||
" \n",
|
||||
" print(\"\\n🧬 提取序列示例:\")\n",
|
||||
" print(\"\\n🧬 Extracted sequence examples:\")\n",
|
||||
" for i, row in prf_sequences.head(3).iterrows():\n",
|
||||
" print(f\"序列 {i+1}: {row['DNA_seqid']}\")\n",
|
||||
" print(f\" FS位置: {row['FS_start']}-{row['FS_end']}\")\n",
|
||||
" print(f\" 链方向: {row['Strand']}\")\n",
|
||||
" print(f\" FS类型: {row['FS_type']}\")\n",
|
||||
" print(f\" 序列片段: {row['399bp'][:50]}...{row['399bp'][-20:]}\")\n",
|
||||
" print(f\"Sequence {i+1}: {row['DNA_seqid']}\")\n",
|
||||
" print(f\" FS position: {row['FS_start']}-{row['FS_end']}\")\n",
|
||||
" print(f\" Strand orientation: {row['Strand']}\")\n",
|
||||
" print(f\" FS type: {row['FS_type']}\")\n",
|
||||
" print(f\" Sequence fragment: {row['399bp'][:50]}...{row['399bp'][-20:]}\")\n",
|
||||
" print()\n",
|
||||
" else:\n",
|
||||
" print(\"❌ 序列提取失败\")\n",
|
||||
" print(\"❌ Sequence extraction failed\")\n",
|
||||
"else:\n",
|
||||
" print(\"⚠️ 跳过序列提取 - 无FScanR结果\")\n",
|
||||
" print(\"⚠️ Skipping sequence extraction - no FScanR results\")\n",
|
||||
" prf_sequences = pd.DataFrame()"
|
||||
]
|
||||
},
|
||||
|
|
@ -343,9 +336,9 @@
|
|||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 4. FScanpy 预测 - 机器学习模型分析\n",
|
||||
"## 4. FScanpy Prediction - Machine Learning Model Analysis\n",
|
||||
"\n",
|
||||
"使用 FScanpy 的机器学习模型对提取的序列进行 PRF 概率预测。"
|
||||
"Use FScanpy's machine learning models to predict PRF probabilities for the extracted sequences."
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
@ -379,26 +372,26 @@
|
|||
}
|
||||
],
|
||||
"source": [
|
||||
"# 初始化预测器\n",
|
||||
"# Initialize predictor\n",
|
||||
"predictor = PRFPredictor()\n",
|
||||
"print(\"🤖 FScanpy预测器初始化完成\")\n",
|
||||
"print(\"🤖 FScanpy predictor initialization complete\")\n",
|
||||
"\n",
|
||||
"# 对FScanR识别的序列进行预测\n",
|
||||
"# Predict FScanR identified sequences\n",
|
||||
"if len(prf_sequences) > 0:\n",
|
||||
" print(f\"\\n🎯 对 {len(prf_sequences)} 个FScanR识别的序列进行预测...\")\n",
|
||||
" print(f\"\\n🎯 Predicting {len(prf_sequences)} sequences identified by FScanR...\")\n",
|
||||
" \n",
|
||||
" fscanr_predictions = predictor.predict_regions(\n",
|
||||
" sequences=prf_sequences['399bp'],\n",
|
||||
" ensemble_weight=0.4 # 平衡配置\n",
|
||||
" ensemble_weight=0.4 # Balanced configuration\n",
|
||||
" )\n",
|
||||
" \n",
|
||||
" # 合并结果\n",
|
||||
" # Merge results\n",
|
||||
" fscanr_predictions = pd.concat([\n",
|
||||
" prf_sequences.reset_index(drop=True),\n",
|
||||
" fscanr_predictions.reset_index(drop=True)\n",
|
||||
" ], axis=1)\n",
|
||||
" \n",
|
||||
" print(\"\\n📊 FScanR+FScanpy预测结果:\")\n",
|
||||
" print(\"\\n📊 FScanR+FScanpy prediction results:\")\n",
|
||||
" result_cols = ['DNA_seqid', 'FS_start', 'FS_type', 'Short_Probability', 'Long_Probability', 'Ensemble_Probability']\n",
|
||||
" print(fscanr_predictions[result_cols].head())"
|
||||
]
|
||||
|
|
@ -429,15 +422,15 @@
|
|||
}
|
||||
],
|
||||
"source": [
|
||||
"# 对验证区域数据进行预测\n",
|
||||
"print(f\"\\n🧪 对 {len(region_data)} 个验证区域进行预测...\")\n",
|
||||
"# Predict validation region data\n",
|
||||
"print(f\"\\n🧪 Predicting {len(region_data)} validation regions...\")\n",
|
||||
"\n",
|
||||
"validation_predictions = predict_prf(\n",
|
||||
" data=region_data.rename(columns={'399bp': 'Long_Sequence'}),\n",
|
||||
" ensemble_weight=0.4\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print(\"\\n📊 验证区域预测结果:\")\n",
|
||||
"print(\"\\n📊 Validation region prediction results:\")\n",
|
||||
"result_cols = ['DNA_seqid', 'label', 'source', 'Short_Probability', 'Long_Probability', 'Ensemble_Probability']\n",
|
||||
"print(validation_predictions[result_cols].head())"
|
||||
]
|
||||
|
|
@ -446,9 +439,9 @@
|
|||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 5. 序列级预测和可视化\n",
|
||||
"## 5. Sequence-level Prediction and Visualization\n",
|
||||
"\n",
|
||||
"选择一个具体的mRNA序列,使用内置的plot_prf_prediction函数进行完整的滑动窗口预测和可视化。"
|
||||
"Select a specific mRNA sequence and use the built-in plot_prf_prediction function for complete sliding window prediction and visualization."
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
@ -557,19 +550,19 @@
|
|||
}
|
||||
],
|
||||
"source": [
|
||||
"# 选择一个序列进行演示\n",
|
||||
"# Select a sequence for demonstration\n",
|
||||
"from Bio import SeqIO\n",
|
||||
"\n",
|
||||
"# 读取第一个mRNA序列作为演示\n",
|
||||
"# Read the first mRNA sequence for demonstration\n",
|
||||
"mrna_sequences = list(SeqIO.parse(mrna_file, \"fasta\"))\n",
|
||||
"demo_seq = mrna_sequences[0] # 选择第一个序列\n",
|
||||
"demo_seq = mrna_sequences[0] # Select the first sequence\n",
|
||||
"\n",
|
||||
"print(f\"🧬 选择演示序列: {demo_seq.id}\")\n",
|
||||
"print(f\"序列长度: {len(demo_seq.seq)} bp\")\n",
|
||||
"print(f\"序列前100bp: {str(demo_seq.seq)[:100]}...\")\n",
|
||||
"print(f\"🧬 Selected demonstration sequence: {demo_seq.id}\")\n",
|
||||
"print(f\"Sequence length: {len(demo_seq.seq)} bp\")\n",
|
||||
"print(f\"First 100bp of sequence: {str(demo_seq.seq)[:100]}...\")\n",
|
||||
"\n",
|
||||
"# 使用内置的plot_prf_prediction函数进行预测和可视化\n",
|
||||
"print(f\"\\n🎯 使用plot_prf_prediction进行序列预测和可视化...\")\n",
|
||||
"# Use built-in plot_prf_prediction function for prediction and visualization\n",
|
||||
"print(f\"\\n🎯 Using plot_prf_prediction for sequence prediction and visualization...\")\n",
|
||||
"\n",
|
||||
"sequence_results, fig = plot_prf_prediction(\n",
|
||||
" sequence=str(demo_seq.seq),\n",
|
||||
|
|
@ -577,18 +570,18 @@
|
|||
" short_threshold=0.2,\n",
|
||||
" long_threshold=0.2,\n",
|
||||
" ensemble_weight=0.6,\n",
|
||||
" title=f\"序列 {demo_seq.id} 的PRF预测结果(条形图+热图)\",\n",
|
||||
" title=f\"PRF Prediction Results for Sequence {demo_seq.id} (Bar Chart + Heatmap)\",\n",
|
||||
" figsize=(16, 8),\n",
|
||||
" dpi=150\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"plt.show()\n",
|
||||
"\n",
|
||||
"print(f\"\\n📊 序列预测结果统计:\")\n",
|
||||
"print(f\" 预测位点总数: {len(sequence_results)}\")\n",
|
||||
"print(f\" 高概率位点 (>0.8): {(sequence_results['Ensemble_Probability'] > 0.8).sum()}\")\n",
|
||||
"print(f\" 中概率位点 (0.4-0.8): {((sequence_results['Ensemble_Probability'] >= 0.4) & (sequence_results['Ensemble_Probability'] <= 0.8)).sum()}\")\n",
|
||||
"print(f\" 最高预测概率: {sequence_results['Ensemble_Probability'].max():.3f}\")"
|
||||
"print(f\"\\n📊 Sequence prediction result statistics:\")\n",
|
||||
"print(f\" Total predicted sites: {len(sequence_results)}\")\n",
|
||||
"print(f\" High probability sites (>0.8): {(sequence_results['Ensemble_Probability'] > 0.8).sum()}\")\n",
|
||||
"print(f\" Medium probability sites (0.4-0.8): {((sequence_results['Ensemble_Probability'] >= 0.4) & (sequence_results['Ensemble_Probability'] <= 0.8)).sum()}\")\n",
|
||||
"print(f\" Highest prediction probability: {sequence_results['Ensemble_Probability'].max():.3f}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
@ -634,53 +627,53 @@
|
|||
}
|
||||
],
|
||||
"source": [
|
||||
"# 打印Top预测位点的概率\n",
|
||||
"# Print top predicted site probabilities\n",
|
||||
"if sequence_results['Ensemble_Probability'].max() > 0.3:\n",
|
||||
" top_predictions = sequence_results.nlargest(5, 'Ensemble_Probability')\n",
|
||||
" print(f\"\\n🔝 Top 5 预测位点:\")\n",
|
||||
" print(f\"\\n🔝 Top 5 predicted sites:\")\n",
|
||||
" for i, (_, row) in enumerate(top_predictions.iterrows(), 1):\n",
|
||||
" print(f\" {i}. 位置 {row['Position']}: \")\n",
|
||||
" print(f\" - Short概率: {row['Short_Probability']:.3f}\")\n",
|
||||
" print(f\" - Long概率: {row['Long_Probability']:.3f}\")\n",
|
||||
" print(f\" - 集成概率: {row['Ensemble_Probability']:.3f}\")\n",
|
||||
" print(f\" - 密码子: {row['Codon']}\")\n",
|
||||
" print(f\" {i}. Position {row['Position']}: \")\n",
|
||||
" print(f\" - Short probability: {row['Short_Probability']:.3f}\")\n",
|
||||
" print(f\" - Long probability: {row['Long_Probability']:.3f}\")\n",
|
||||
" print(f\" - Ensemble probability: {row['Ensemble_Probability']:.3f}\")\n",
|
||||
" print(f\" - Codon: {row['Codon']}\")\n",
|
||||
"else:\n",
|
||||
" print(\"\\n💡 该序列没有检测到高概率的PRF位点\")\n",
|
||||
" print(\"\\n💡 No high-probability PRF sites detected in this sequence\")\n",
|
||||
"\n",
|
||||
"print(\"\\n📊 可视化分析完成!\")\n",
|
||||
"print(\"图表包含热图和条形图,展示了整个序列的PRF预测概率分布。\")"
|
||||
"print(\"\\n📊 Visualization analysis complete!\")\n",
|
||||
"print(\"The chart contains heatmaps and bar charts showing the PRF prediction probability distribution across the entire sequence.\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 📝 分析总结\n",
|
||||
"## 📝 Analysis Summary\n",
|
||||
"\n",
|
||||
"### 🎯 主要发现\n",
|
||||
"1. **数据质量**: 测试数据集包含真实的BLASTX比对结果和验证区域\n",
|
||||
"2. **FScanR效果**: 从BLASTX结果中识别出潜在PRF位点\n",
|
||||
"3. **模型性能**: Short和Long模型在不同场景下各有优势\n",
|
||||
"4. **预测结果**: 集成模型提供了更稳定的预测性能\n",
|
||||
"5. **可视化**: 内置绘图函数生成清晰的热图和条形图\n",
|
||||
"### 🎯 Key Findings\n",
|
||||
"1. **Data Quality**: Test dataset contains real BLASTX alignment results and validation regions\n",
|
||||
"2. **FScanR Performance**: Successfully identified potential PRF sites from BLASTX results\n",
|
||||
"3. **Model Performance**: Short and Long models each have advantages in different scenarios\n",
|
||||
"4. **Prediction Results**: Ensemble model provides more stable prediction performance\n",
|
||||
"5. **Visualization**: Built-in plotting functions generate clear heatmaps and bar charts\n",
|
||||
"\n",
|
||||
"### 🔧 最佳实践\n",
|
||||
"- **数据预处理**: 确保BLASTX结果格式正确\n",
|
||||
"- **参数设置**: 使用默认的集成权重(0.4:0.6)获得平衡性能\n",
|
||||
"- **结果解读**: 在使用FScanpy对整条序列进行预测时,不应该使用0.5作为阈值,而应该比较不同位置的概率高低\n",
|
||||
"- **可视化**: 使用plot_prf_prediction函数生成标准化图表\n",
|
||||
"### 🔧 Best Practices\n",
|
||||
"- **Data Preprocessing**: Ensure BLASTX results are in correct format\n",
|
||||
"- **Parameter Settings**: Use default ensemble weights (0.4:0.6) for balanced performance\n",
|
||||
"- **Result Interpretation**: When using FScanpy for whole sequence prediction, don't use 0.5 as threshold, but compare relative probabilities across positions\n",
|
||||
"- **Visualization**: Use plot_prf_prediction function to generate standardized plots\n",
|
||||
"\n",
|
||||
"### 📚 使用建议\n",
|
||||
"1. **阈值选择**: 根据应用场景调整概率阈值\n",
|
||||
"2. **结果验证**: 结合生物学知识验证预测结果\n",
|
||||
"3. **性能优化**: 对于大规模数据使用合理的滑动窗口大小\n",
|
||||
"4. **可视化参数**: 调整figsize和dpi获得最佳显示效果"
|
||||
"### 📚 Usage Recommendations\n",
|
||||
"1. **Threshold Selection**: Adjust probability thresholds based on application scenarios\n",
|
||||
"2. **Result Validation**: Validate prediction results with biological knowledge\n",
|
||||
"3. **Performance Optimization**: Use reasonable sliding window sizes for large-scale data\n",
|
||||
"4. **Visualization Parameters**: Adjust figsize and dpi for optimal display"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"display_name": "tf200",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
|
|
@ -694,7 +687,7 @@
|
|||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.5"
|
||||
"version": "3.9.0"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
|
|
|||
Loading…
Reference in New Issue