优化绘图,增加使用介绍的ipynb

This commit is contained in:
Chenlab 2025-06-11 21:18:52 +08:00
parent 3833704c47
commit 089df9c4a6
3 changed files with 692 additions and 687 deletions

View File

@ -19,61 +19,61 @@ def predict_prf(
model_dir: str = None
) -> pd.DataFrame:
"""
PRF位点预测函数
PRF site prediction function
Args:
sequence: 单个或多个DNA序列用于滑动窗口预测
data: DataFrame数据必须包含'Long_Sequence''399bp'用于区域预测
window_size: 滑动窗口大小默认为3
short_threshold: Short模型(HistGB)概率阈值默认为0.1
ensemble_weight: Short模型在集成中的权重默认为0.4Long权重为0.6
model_dir: 模型文件目录路径可选
sequence: Single or multiple DNA sequences for sliding window prediction
data: DataFrame data, must contain 'Long_Sequence' or '399bp' column for region prediction
window_size: Sliding window size (default: 3)
short_threshold: Short model (HistGB) probability threshold (default: 0.1)
ensemble_weight: Weight of short model in ensemble (default: 0.4, long weight: 0.6)
model_dir: Model directory path (optional)
Returns:
pandas.DataFrame: 预测结果包含以下主要字段
- Short_Probability: Short模型预测概率
- Long_Probability: Long模型预测概率
- Ensemble_Probability: 集成预测概率主要结果
- Ensemble_Weights: 权重配置信息
pandas.DataFrame: Prediction results containing the following main fields:
- Short_Probability: Short model prediction probability
- Long_Probability: Long model prediction probability
- Ensemble_Probability: Ensemble prediction probability (main result)
- Ensemble_Weights: Weight configuration information
Examples:
# 1. 单条序列滑动窗口预测
# 1. Single sequence sliding window prediction
>>> from FScanpy import predict_prf
>>> sequence = "ATGCGTACGT..."
>>> results = predict_prf(sequence=sequence)
# 2. 多条序列滑动窗口预测
# 2. Multiple sequences sliding window prediction
>>> sequences = ["ATGCGTACGT...", "GCTATAGCAT..."]
>>> results = predict_prf(sequence=sequences)
# 3. 自定义集成权重比例
>>> results = predict_prf(sequence=sequence, ensemble_weight=0.3) # 3:7 比例
# 3. Custom ensemble weight ratio
>>> results = predict_prf(sequence=sequence, ensemble_weight=0.3) # 3:7 ratio
# 4. DataFrame区域预测
# 4. DataFrame region prediction
>>> import pandas as pd
>>> data = pd.DataFrame({
... 'Long_Sequence': ['ATGCGT...', 'GCTATAG...'] # 或使用 '399bp'
... 'Long_Sequence': ['ATGCGT...', 'GCTATAG...'] # or use '399bp'
... })
>>> results = predict_prf(data=data)
"""
predictor = PRFPredictor(model_dir=model_dir)
# 验证输入参数
# Validate input parameters
if sequence is None and data is None:
raise ValueError("必须提供sequence或data参数之一")
raise ValueError("Must provide either sequence or data parameter")
if sequence is not None and data is not None:
raise ValueError("sequence和data参数不能同时提供")
raise ValueError("Cannot provide both sequence and data parameters")
if not (0.0 <= ensemble_weight <= 1.0):
raise ValueError("ensemble_weight 必须在 0.0 到 1.0 之间")
raise ValueError("ensemble_weight must be between 0.0 and 1.0")
# 滑动窗口预测模式
# Sliding window prediction mode
if sequence is not None:
if isinstance(sequence, str):
# 单条序列预测
# Single sequence prediction
return predictor.predict_sequence(
sequence, window_size, short_threshold, ensemble_weight)
elif isinstance(sequence, (list, tuple)):
# 多条序列预测
# Multiple sequences prediction
results = []
for i, seq in enumerate(sequence, 1):
try:
@ -82,29 +82,29 @@ def predict_prf(
result['Sequence_ID'] = f'seq_{i}'
results.append(result)
except Exception as e:
print(f"警告:序列 {i} 预测失败 - {str(e)}")
print(f"Warning: Sequence {i} prediction failed - {str(e)}")
return pd.concat(results, ignore_index=True) if results else pd.DataFrame()
# 区域化预测模式
# Region prediction mode
else:
if not isinstance(data, pd.DataFrame):
raise ValueError("data参数必须是pandas DataFrame类型")
raise ValueError("data parameter must be pandas DataFrame type")
# 检查列名(支持新旧两种命名)
# Check column names (support both new and old naming conventions)
seq_column = None
if 'Long_Sequence' in data.columns:
seq_column = 'Long_Sequence'
elif '399bp' in data.columns:
seq_column = '399bp'
else:
raise ValueError("DataFrame必须包含'Long_Sequence''399bp'")
raise ValueError("DataFrame must contain 'Long_Sequence' or '399bp' column")
# 调用区域预测函数
# Call region prediction function
try:
results = predictor.predict_regions(
data[seq_column], short_threshold, ensemble_weight)
# 添加原始数据的其他列
# Add other columns from original data
for col in data.columns:
if col not in ['Long_Sequence', '399bp', 'Short_Sequence', '33bp']:
results[col] = data[col].values
@ -112,8 +112,8 @@ def predict_prf(
return results
except Exception as e:
print(f"警告:区域预测失败 - {str(e)}")
# 创建空结果
print(f"Warning: Region prediction failed - {str(e)}")
# Create empty results
long_weight = 1.0 - ensemble_weight
results = pd.DataFrame({
'Short_Probability': [0.0] * len(data),
@ -122,7 +122,7 @@ def predict_prf(
'Ensemble_Weights': [f'Short:{ensemble_weight:.1f}, Long:{long_weight:.1f}'] * len(data)
})
# 添加原始数据列
# Add original data columns
for col in data.columns:
results[col] = data[col].values
@ -136,59 +136,59 @@ def plot_prf_prediction(
ensemble_weight: float = 0.4,
title: str = None,
save_path: str = None,
figsize: tuple = (12, 6),
figsize: tuple = (12, 8),
dpi: int = 300,
model_dir: str = None
) -> tuple:
"""
绘制序列PRF预测结果的移码概率图
Plot PRF prediction results for sequence frameshifting probability
Args:
sequence: 输入DNA序列
window_size: 滑动窗口大小默认为3
short_threshold: Short模型(HistGB)过滤阈值默认为0.65
long_threshold: Long模型(BiLSTM-CNN)过滤阈值默认为0.8
ensemble_weight: Short模型在集成中的权重默认为0.4Long权重为0.6
title: 图片标题可选
save_path: 保存路径可选如果提供则保存图片
figsize: 图片尺寸默认为(12, 6)
dpi: 图片分辨率默认为300
model_dir: 模型文件目录路径可选
sequence: Input DNA sequence
window_size: Sliding window size (default: 3)
short_threshold: Short model (HistGB) filtering threshold (default: 0.65)
long_threshold: Long model (BiLSTM-CNN) filtering threshold (default: 0.8)
ensemble_weight: Weight of short model in ensemble (default: 0.4, long weight: 0.6)
title: Plot title (optional)
save_path: Save path (optional, saves plot if provided)
figsize: Figure size (default: (12, 8))
dpi: Figure resolution (default: 300)
model_dir: Model directory path (optional)
Returns:
tuple: (pd.DataFrame, matplotlib.figure.Figure) 预测结果和图形对象
tuple: (pd.DataFrame, matplotlib.figure.Figure) prediction results and figure object
Examples:
# 1. 简单绘图
# 1. Simple plotting
>>> from FScanpy import plot_prf_prediction
>>> sequence = "ATGCGTACGT..."
>>> results, fig = plot_prf_prediction(sequence)
>>> plt.show()
# 2. 自定义阈值和集成权重
# 2. Custom thresholds and ensemble weights
>>> results, fig = plot_prf_prediction(
... sequence,
... short_threshold=0.7,
... long_threshold=0.85,
... ensemble_weight=0.3, # 3:7 权重比例
... title="自定义权重的预测结果",
... ensemble_weight=0.3, # 3:7 weight ratio
... title="Custom Weight Prediction Results",
... save_path="prediction_result.png"
... )
# 3. 等权重组合
# 3. Equal weight combination
>>> results, fig = plot_prf_prediction(
... sequence,
... ensemble_weight=0.5 # 5:5 等权重
... ensemble_weight=0.5 # 5:5 equal weights
... )
# 4. Long模型主导
# 4. Long model dominated
>>> results, fig = plot_prf_prediction(
... sequence,
... ensemble_weight=0.2 # 2:8 权重Long模型主导
... ensemble_weight=0.2 # 2:8 weights, long model dominated
... )
"""
if not (0.0 <= ensemble_weight <= 1.0):
raise ValueError("ensemble_weight 必须在 0.0 到 1.0 之间")
raise ValueError("ensemble_weight must be between 0.0 and 1.0")
predictor = PRFPredictor(model_dir=model_dir)

View File

@ -1,487 +1,499 @@
import os
import pickle
import numpy as np
import pandas as pd
from tensorflow.keras.models import load_model
from .features.sequence import SequenceFeatureExtractor
from .features.cnn_input import CNNInputProcessor
from .utils import extract_window_sequences
import matplotlib.pyplot as plt
import joblib
class PRFPredictor:
def __init__(self, model_dir=None):
"""
初始化PRF预测器
Args:
model_dir: 模型目录路径可选
"""
if model_dir is None:
from pkg_resources import resource_filename
model_dir = resource_filename('FScanpy', 'pretrained')
try:
# 加载模型 - 使用新的命名约定
self.short_model = self._load_pickle(os.path.join(model_dir, 'short.pkl')) # HistGB模型
self.long_model = self._load_pickle(os.path.join(model_dir, 'long.pkl')) # BiLSTM-CNN模型
# 初始化特征提取器和CNN处理器使用与训练时相同的序列长度
self.short_seq_length = 33 # HistGB使用的序列长度
self.long_seq_length = 399 # BiLSTM-CNN使用的序列长度
# 初始化特征提取器和CNN输入处理器
self.feature_extractor = SequenceFeatureExtractor(seq_length=self.short_seq_length)
self.cnn_processor = CNNInputProcessor(max_length=self.long_seq_length)
# 检测模型类型以优化预测性能
self._detect_model_types()
except FileNotFoundError as e:
raise FileNotFoundError(f"无法找到模型文件: {str(e)}。请确保模型文件 'short.pkl''long.pkl' 存在于 {model_dir}")
except Exception as e:
raise Exception(f"加载模型出错: {str(e)}")
def _load_pickle(self, path):
"""安全加载pickle文件"""
try:
return joblib.load(path)
except Exception as e:
raise FileNotFoundError(f"无法加载模型文件 {path}: {str(e)}")
def _detect_model_types(self):
"""检测模型类型以优化预测性能"""
self.short_is_sklearn = hasattr(self.short_model, 'predict_proba')
self.long_is_sklearn = hasattr(self.long_model, 'predict_proba')
def _predict_model(self, model, features, is_sklearn, seq_length):
"""统一的模型预测方法"""
try:
if is_sklearn:
# sklearn模型使用特征向量
if isinstance(features, np.ndarray) and features.ndim > 1:
features = features.flatten()
features_2d = np.array([features])
pred = model.predict_proba(features_2d)
return pred[0][1]
else:
# 深度学习模型
if seq_length == self.long_seq_length:
# 对于长序列使用CNN处理器
model_input = self.cnn_processor.prepare_sequence(features)
else:
# 对于短序列,转换为数值编码
base_to_num = {'A': 1, 'T': 2, 'G': 3, 'C': 4, 'N': 0}
seq_numeric = [base_to_num.get(base, 0) for base in features.upper()]
model_input = np.array(seq_numeric).reshape(1, len(seq_numeric), 1)
# 统一的预测调用
try:
pred = model.predict(model_input, verbose=0)
except TypeError:
pred = model.predict(model_input)
# 处理预测结果
if isinstance(pred, list):
pred = pred[0]
if hasattr(pred, 'shape') and len(pred.shape) > 1 and pred.shape[1] > 1:
return pred[0][1]
else:
return pred[0][0] if hasattr(pred[0], '__getitem__') else pred[0]
except Exception as e:
raise Exception(f"模型预测失败: {str(e)}")
def predict_single_position(self, fs_period, full_seq, short_threshold=0.1, ensemble_weight=0.4):
'''
预测单个位置的PRF状态
Args:
fs_period: 33bp序列 (short模型使用)
full_seq: 完整序列 (long模型使用)
short_threshold: short模型的概率阈值 (默认为0.1)
ensemble_weight: short模型在集成中的权重 (默认为0.4long权重为0.6)
Returns:
dict: 包含预测概率的字典
'''
try:
# 验证权重参数
if not (0.0 <= ensemble_weight <= 1.0):
raise ValueError("ensemble_weight 必须在 0.0 到 1.0 之间")
long_weight = 1.0 - ensemble_weight
# 处理序列长度
if len(fs_period) > self.short_seq_length:
fs_period = self.feature_extractor.trim_sequence(fs_period, self.short_seq_length)
# Short模型预测 (HistGB)
try:
if self.short_is_sklearn:
short_features = self.feature_extractor.extract_features(fs_period)
short_prob = self._predict_model(self.short_model, short_features, True, self.short_seq_length)
else:
short_prob = self._predict_model(self.short_model, fs_period, False, self.short_seq_length)
except Exception as e:
print(f"Short模型预测时出错: {str(e)}")
short_prob = 0.0
# 如果short概率低于阈值则跳过long模型
if short_prob < short_threshold:
return {
'Short_Probability': short_prob,
'Long_Probability': 0.0,
'Ensemble_Probability': 0.0,
'Ensemble_Weights': f'Short:{ensemble_weight:.1f}, Long:{long_weight:.1f}'
}
# Long模型预测 (BiLSTM-CNN)
try:
if self.long_is_sklearn:
long_features = self.feature_extractor.extract_features(full_seq)
long_prob = self._predict_model(self.long_model, long_features, True, self.long_seq_length)
else:
long_prob = self._predict_model(self.long_model, full_seq, False, self.long_seq_length)
except Exception as e:
print(f"Long模型预测时出错: {str(e)}")
long_prob = 0.0
# 计算集成概率
try:
ensemble_prob = ensemble_weight * short_prob + long_weight * long_prob
except Exception as e:
print(f"计算集成概率时出错: {str(e)}")
ensemble_prob = (short_prob + long_prob) / 2
return {
'Short_Probability': short_prob,
'Long_Probability': long_prob,
'Ensemble_Probability': ensemble_prob,
'Ensemble_Weights': f'Short:{ensemble_weight:.1f}, Long:{long_weight:.1f}'
}
except Exception as e:
raise Exception(f"预测过程出错: {str(e)}")
def predict_sequence(self, sequence, window_size=3, short_threshold=0.1, ensemble_weight=0.4):
"""
预测完整序列中的PRF位点滑动窗口方法
Args:
sequence: 输入DNA序列
window_size: 滑动窗口大小 (默认为3)
short_threshold: short模型概率阈值 (默认为0.1)
ensemble_weight: short模型在集成中的权重 (默认为0.4)
Returns:
pd.DataFrame: 包含预测结果的DataFrame
"""
if window_size < 1:
raise ValueError("窗口大小必须大于等于1")
if short_threshold < 0:
raise ValueError("short模型阈值必须大于等于0")
if not (0.0 <= ensemble_weight <= 1.0):
raise ValueError("ensemble_weight 必须在 0.0 到 1.0 之间")
results = []
long_weight = 1.0 - ensemble_weight
try:
# 确保序列为字符串并转换为大写
sequence = str(sequence).upper()
# 滑动窗口预测
for pos in range(0, len(sequence) - 2, window_size):
# 提取窗口序列
fs_period, full_seq = extract_window_sequences(sequence, pos)
if fs_period is None or full_seq is None:
continue
# 预测并记录结果
pred = self.predict_single_position(fs_period, full_seq, short_threshold, ensemble_weight)
pred.update({
'Position': pos,
'Codon': sequence[pos:pos+3],
'Short_Sequence': fs_period, # 更清晰的命名
'Long_Sequence': full_seq # 更清晰的命名
})
results.append(pred)
# 创建结果DataFrame
results_df = pd.DataFrame(results)
return results_df
except Exception as e:
raise Exception(f"序列预测过程出错: {str(e)}")
def plot_sequence_prediction(self, sequence, window_size=3, short_threshold=0.65,
long_threshold=0.8, ensemble_weight=0.4, title=None, save_path=None,
figsize=(12, 6), dpi=300):
"""
绘制序列预测结果的移码概率图
Args:
sequence: 输入DNA序列
window_size: 滑动窗口大小 (默认为3)
short_threshold: Short模型(HistGB)过滤阈值 (默认为0.65)
long_threshold: Long模型(BiLSTM-CNN)过滤阈值 (默认为0.8)
ensemble_weight: Short模型在集成中的权重 (默认为0.4)
title: 图片标题 (可选)
save_path: 保存路径 (可选如果提供则保存图片)
figsize: 图片尺寸 (默认为(12, 6))
dpi: 图片分辨率 (默认为300)
Returns:
tuple: (pd.DataFrame, matplotlib.figure.Figure) 预测结果和图形对象
"""
try:
# 验证权重参数
if not (0.0 <= ensemble_weight <= 1.0):
raise ValueError("ensemble_weight 必须在 0.0 到 1.0 之间")
long_weight = 1.0 - ensemble_weight
# 获取预测结果 - 使用新的方法名
results_df = self.predict_sequence(sequence, window_size=window_size,
short_threshold=0.1, ensemble_weight=ensemble_weight)
if results_df.empty:
raise ValueError("预测结果为空,请检查输入序列")
# 获取序列长度
seq_length = len(sequence)
# 计算显示宽度
prob_width = max(1, seq_length // 300) # 概率标记的宽度
# 创建图形,包含两个子图
fig = plt.figure(figsize=figsize)
gs = fig.add_gridspec(2, 1, height_ratios=[0.15, 1], hspace=0.3)
# 设置标题
if title:
fig.suptitle(title, y=0.95, fontsize=12)
else:
fig.suptitle(f'序列移码概率预测结果 (权重 {ensemble_weight:.1f}:{long_weight:.1f})', y=0.95, fontsize=12)
# 预测概率热图
ax0 = fig.add_subplot(gs[0])
prob_data = np.zeros((1, seq_length))
# 应用双重阈值过滤
for _, row in results_df.iterrows():
pos = int(row['Position'])
if (row['Short_Probability'] >= short_threshold and
row['Long_Probability'] >= long_threshold):
# 为每个满足阈值的位置设置概率值
start = max(0, pos - prob_width//2)
end = min(seq_length, pos + prob_width//2 + 1)
prob_data[0, start:end] = row['Ensemble_Probability']
im = ax0.imshow(prob_data, cmap='Reds', aspect='auto', vmin=0, vmax=1,
interpolation='nearest')
ax0.set_xticks([])
ax0.set_yticks([])
ax0.set_title(f'预测概率热图 (Short≥{short_threshold}, Long≥{long_threshold})',
pad=5, fontsize=10)
# 主图(条形图)
ax1 = fig.add_subplot(gs[1])
# 应用过滤阈值
filtered_probs = results_df['Ensemble_Probability'].copy()
mask = ((results_df['Short_Probability'] < short_threshold) |
(results_df['Long_Probability'] < long_threshold))
filtered_probs[mask] = 0
# 绘制条形图
bars = ax1.bar(results_df['Position'], filtered_probs,
alpha=0.7, color='darkred', width=max(1, window_size))
# 设置x轴刻度
step = max(seq_length // 10, 50)
x_ticks = np.arange(0, seq_length, step)
ax1.set_xticks(x_ticks)
ax1.tick_params(axis='x', rotation=45)
# 设置标签和标题
ax1.set_xlabel('序列位置 (bp)', fontsize=10)
ax1.set_ylabel('移码概率', fontsize=10)
ax1.set_title(f'移码概率分布 (集成权重 {ensemble_weight:.1f}:{long_weight:.1f})', fontsize=11)
# 设置y轴范围
ax1.set_ylim(0, 1)
# 添加网格
ax1.grid(True, alpha=0.3)
# 添加阈值和权重说明
info_text = (f'过滤阈值: Short≥{short_threshold}, Long≥{long_threshold}\n'
f'集成权重: Short:{ensemble_weight:.1f}, Long:{long_weight:.1f}')
ax1.text(0.02, 0.95, info_text, transform=ax1.transAxes,
fontsize=9, verticalalignment='top',
bbox=dict(boxstyle='round', facecolor='lightgray', alpha=0.8))
# 确保所有子图的x轴范围一致
for ax in [ax0, ax1]:
ax.set_xlim(-1, seq_length)
# 调整布局
plt.tight_layout()
# 如果提供了保存路径,则保存图片
if save_path:
plt.savefig(save_path, dpi=dpi, bbox_inches='tight')
# 同时保存PDF版本
if save_path.endswith('.png'):
pdf_path = save_path.replace('.png', '.pdf')
plt.savefig(pdf_path, bbox_inches='tight')
print(f"图片已保存至: {save_path}")
return results_df, fig
except Exception as e:
raise Exception(f"绘制序列预测图时出错: {str(e)}")
def predict_regions(self, sequences, short_threshold=0.1, ensemble_weight=0.4):
'''
预测区域序列批量预测已知的399bp序列
Args:
sequences: 399bp序列或包含399bp序列的DataFrame/Series/list
short_threshold: short模型概率阈值 (默认为0.1)
ensemble_weight: short模型在集成中的权重 (默认为0.4)
Returns:
DataFrame: 包含所有序列预测概率的DataFrame
'''
try:
# 验证权重参数
if not (0.0 <= ensemble_weight <= 1.0):
raise ValueError("ensemble_weight 必须在 0.0 到 1.0 之间")
# 统一输入格式
if isinstance(sequences, (pd.DataFrame, pd.Series)):
sequences = sequences.tolist()
elif isinstance(sequences, str):
sequences = [sequences]
results = []
for i, seq399 in enumerate(sequences):
try:
# 从399bp序列中截取中心的33bp (short模型使用)
seq33 = self._extract_center_sequence(seq399, target_length=self.short_seq_length)
# 使用统一的预测方法
pred_result = self.predict_single_position(seq33, seq399, short_threshold, ensemble_weight)
pred_result.update({
'Short_Sequence': seq33,
'Long_Sequence': seq399
})
results.append(pred_result)
except Exception as e:
print(f"处理第 {i+1} 个序列时出错: {str(e)}")
long_weight = 1.0 - ensemble_weight
results.append({
'Short_Probability': 0.0,
'Long_Probability': 0.0,
'Ensemble_Probability': 0.0,
'Ensemble_Weights': f'Short:{ensemble_weight:.1f}, Long:{long_weight:.1f}',
'Short_Sequence': self._extract_center_sequence(seq399, target_length=self.short_seq_length) if len(seq399) >= self.short_seq_length else seq399,
'Long_Sequence': seq399
})
return pd.DataFrame(results)
except Exception as e:
raise Exception(f"区域预测过程出错: {str(e)}")
def _extract_center_sequence(self, sequence, target_length=33):
"""从序列中心位置提取指定长度的子序列"""
# 确保序列为字符串
sequence = str(sequence).upper()
# 如果序列长度小于目标长度,返回原序列
if len(sequence) <= target_length:
return sequence
# 计算中心位置
center = len(sequence) // 2
half_target = target_length // 2
# 提取中心序列
start = center - half_target
end = start + target_length
# 边界检查
if start < 0:
start = 0
end = target_length
elif end > len(sequence):
end = len(sequence)
start = end - target_length
return sequence[start:end]
# 兼容性方法(向后兼容,但标记为废弃)
def predict_full(self, sequence, window_size=3, short_threshold=0.1, short_weight=0.4, plot=False):
"""
已废弃请使用 predict_sequence() 方法
向后兼容的方法内部调用新的 predict_sequence()
"""
import warnings
warnings.warn("predict_full() 已废弃,请使用 predict_sequence() 方法", DeprecationWarning, stacklevel=2)
# 调用新方法并添加兼容性字段
results_df = self.predict_sequence(sequence, window_size, short_threshold, short_weight)
# 添加兼容性字段
if 'Ensemble_Probability' in results_df.columns:
results_df['Voting_Probability'] = results_df['Ensemble_Probability']
results_df['Weighted_Probability'] = results_df['Ensemble_Probability']
if 'Ensemble_Weights' in results_df.columns:
results_df['Weight_Info'] = results_df['Ensemble_Weights']
if 'Short_Sequence' in results_df.columns:
results_df['33bp'] = results_df['Short_Sequence']
if 'Long_Sequence' in results_df.columns:
results_df['399bp'] = results_df['Long_Sequence']
if plot:
# 如果需要绘图,调用绘图方法
_, fig = self.plot_sequence_prediction(sequence, window_size, 0.65, 0.8, short_weight)
return results_df, fig
return results_df
def predict_region(self, seq, short_threshold=0.1, short_weight=0.4):
"""
已废弃请使用 predict_regions() 方法
向后兼容的方法内部调用新的 predict_regions()
"""
import warnings
warnings.warn("predict_region() 已废弃,请使用 predict_regions() 方法", DeprecationWarning, stacklevel=2)
# 调用新方法并添加兼容性字段
results_df = self.predict_regions(seq, short_threshold, short_weight)
# 添加兼容性字段
if 'Ensemble_Probability' in results_df.columns:
results_df['Voting_Probability'] = results_df['Ensemble_Probability']
results_df['Weighted_Probability'] = results_df['Ensemble_Probability']
if 'Ensemble_Weights' in results_df.columns:
results_df['Weight_Info'] = results_df['Ensemble_Weights']
if 'Short_Sequence' in results_df.columns:
results_df['33bp'] = results_df['Short_Sequence']
if 'Long_Sequence' in results_df.columns:
results_df['399bp'] = results_df['Long_Sequence']
import os
import pickle
import numpy as np
import pandas as pd
from tensorflow.keras.models import load_model
from .features.sequence import SequenceFeatureExtractor
from .features.cnn_input import CNNInputProcessor
from .utils import extract_window_sequences
import matplotlib.pyplot as plt
import joblib
class PRFPredictor:
def __init__(self, model_dir=None):
"""
初始化PRF预测器
Args:
model_dir: 模型目录路径可选
"""
if model_dir is None:
from pkg_resources import resource_filename
model_dir = resource_filename('FScanpy', 'pretrained')
try:
# 加载模型 - 使用新的命名约定
self.short_model = self._load_pickle(os.path.join(model_dir, 'short.pkl')) # HistGB模型
self.long_model = self._load_pickle(os.path.join(model_dir, 'long.pkl')) # BiLSTM-CNN模型
# 初始化特征提取器和CNN处理器使用与训练时相同的序列长度
self.short_seq_length = 33 # HistGB使用的序列长度
self.long_seq_length = 399 # BiLSTM-CNN使用的序列长度
# 初始化特征提取器和CNN输入处理器
self.feature_extractor = SequenceFeatureExtractor(seq_length=self.short_seq_length)
self.cnn_processor = CNNInputProcessor(max_length=self.long_seq_length)
# 检测模型类型以优化预测性能
self._detect_model_types()
except FileNotFoundError as e:
raise FileNotFoundError(f"无法找到模型文件: {str(e)}。请确保模型文件 'short.pkl''long.pkl' 存在于 {model_dir}")
except Exception as e:
raise Exception(f"加载模型出错: {str(e)}")
def _load_pickle(self, path):
"""安全加载pickle文件"""
try:
return joblib.load(path)
except Exception as e:
raise FileNotFoundError(f"无法加载模型文件 {path}: {str(e)}")
def _detect_model_types(self):
"""检测模型类型以优化预测性能"""
self.short_is_sklearn = hasattr(self.short_model, 'predict_proba')
self.long_is_sklearn = hasattr(self.long_model, 'predict_proba')
def _predict_model(self, model, features, is_sklearn, seq_length):
"""统一的模型预测方法"""
try:
if is_sklearn:
# sklearn模型使用特征向量
if isinstance(features, np.ndarray) and features.ndim > 1:
features = features.flatten()
features_2d = np.array([features])
pred = model.predict_proba(features_2d)
return pred[0][1]
else:
# 深度学习模型
if seq_length == self.long_seq_length:
# 对于长序列使用CNN处理器
model_input = self.cnn_processor.prepare_sequence(features)
else:
# 对于短序列,转换为数值编码
base_to_num = {'A': 1, 'T': 2, 'G': 3, 'C': 4, 'N': 0}
seq_numeric = [base_to_num.get(base, 0) for base in features.upper()]
model_input = np.array(seq_numeric).reshape(1, len(seq_numeric), 1)
# 统一的预测调用
try:
pred = model.predict(model_input, verbose=0)
except TypeError:
pred = model.predict(model_input)
# 处理预测结果
if isinstance(pred, list):
pred = pred[0]
if hasattr(pred, 'shape') and len(pred.shape) > 1 and pred.shape[1] > 1:
return pred[0][1]
else:
return pred[0][0] if hasattr(pred[0], '__getitem__') else pred[0]
except Exception as e:
raise Exception(f"模型预测失败: {str(e)}")
def predict_single_position(self, fs_period, full_seq, short_threshold=0.1, ensemble_weight=0.4):
'''
预测单个位置的PRF状态
Args:
fs_period: 33bp序列 (short模型使用)
full_seq: 完整序列 (long模型使用)
short_threshold: short模型的概率阈值 (默认为0.1)
ensemble_weight: short模型在集成中的权重 (默认为0.4long权重为0.6)
Returns:
dict: 包含预测概率的字典
'''
try:
# 验证权重参数
if not (0.0 <= ensemble_weight <= 1.0):
raise ValueError("ensemble_weight 必须在 0.0 到 1.0 之间")
long_weight = 1.0 - ensemble_weight
# 处理序列长度
if len(fs_period) > self.short_seq_length:
fs_period = self.feature_extractor.trim_sequence(fs_period, self.short_seq_length)
# Short模型预测 (HistGB)
try:
if self.short_is_sklearn:
short_features = self.feature_extractor.extract_features(fs_period)
short_prob = self._predict_model(self.short_model, short_features, True, self.short_seq_length)
else:
short_prob = self._predict_model(self.short_model, fs_period, False, self.short_seq_length)
except Exception as e:
print(f"Short模型预测时出错: {str(e)}")
short_prob = 0.0
# 如果short概率低于阈值则跳过long模型
if short_prob < short_threshold:
return {
'Short_Probability': short_prob,
'Long_Probability': 0.0,
'Ensemble_Probability': 0.0,
'Ensemble_Weights': f'Short:{ensemble_weight:.1f}, Long:{long_weight:.1f}'
}
# Long模型预测 (BiLSTM-CNN)
try:
if self.long_is_sklearn:
long_features = self.feature_extractor.extract_features(full_seq)
long_prob = self._predict_model(self.long_model, long_features, True, self.long_seq_length)
else:
long_prob = self._predict_model(self.long_model, full_seq, False, self.long_seq_length)
except Exception as e:
print(f"Long模型预测时出错: {str(e)}")
long_prob = 0.0
# 计算集成概率
try:
ensemble_prob = ensemble_weight * short_prob + long_weight * long_prob
except Exception as e:
print(f"计算集成概率时出错: {str(e)}")
ensemble_prob = (short_prob + long_prob) / 2
return {
'Short_Probability': short_prob,
'Long_Probability': long_prob,
'Ensemble_Probability': ensemble_prob,
'Ensemble_Weights': f'Short:{ensemble_weight:.1f}, Long:{long_weight:.1f}'
}
except Exception as e:
raise Exception(f"预测过程出错: {str(e)}")
def predict_sequence(self, sequence, window_size=3, short_threshold=0.1, ensemble_weight=0.4):
"""
预测完整序列中的PRF位点滑动窗口方法
Args:
sequence: 输入DNA序列
window_size: 滑动窗口大小 (默认为3)
short_threshold: short模型概率阈值 (默认为0.1)
ensemble_weight: short模型在集成中的权重 (默认为0.4)
Returns:
pd.DataFrame: 包含预测结果的DataFrame
"""
if window_size < 1:
raise ValueError("窗口大小必须大于等于1")
if short_threshold < 0:
raise ValueError("short模型阈值必须大于等于0")
if not (0.0 <= ensemble_weight <= 1.0):
raise ValueError("ensemble_weight 必须在 0.0 到 1.0 之间")
results = []
long_weight = 1.0 - ensemble_weight
try:
# 确保序列为字符串并转换为大写
sequence = str(sequence).upper()
# 滑动窗口预测
for pos in range(0, len(sequence) - 2, window_size):
# 提取窗口序列
fs_period, full_seq = extract_window_sequences(sequence, pos)
if fs_period is None or full_seq is None:
continue
# 预测并记录结果
pred = self.predict_single_position(fs_period, full_seq, short_threshold, ensemble_weight)
pred.update({
'Position': pos,
'Codon': sequence[pos:pos+3],
'Short_Sequence': fs_period, # 更清晰的命名
'Long_Sequence': full_seq # 更清晰的命名
})
results.append(pred)
# 创建结果DataFrame
results_df = pd.DataFrame(results)
return results_df
except Exception as e:
raise Exception(f"序列预测过程出错: {str(e)}")
def plot_sequence_prediction(self, sequence, window_size=3, short_threshold=0.65,
long_threshold=0.8, ensemble_weight=0.4, title=None, save_path=None,
figsize=(12, 8), dpi=300):
"""
Plot sequence PRF prediction results
Args:
sequence: Input DNA sequence
window_size: Sliding window size (default: 3)
short_threshold: Short model (HistGB) filtering threshold (default: 0.65)
long_threshold: Long model (BiLSTM-CNN) filtering threshold (default: 0.8)
ensemble_weight: Weight of short model in ensemble (default: 0.4)
title: Plot title (optional)
save_path: Save path (optional, saves plot if provided)
figsize: Figure size (default: (12, 8))
dpi: Figure resolution (default: 300)
Returns:
tuple: (pd.DataFrame, matplotlib.figure.Figure) prediction results and figure object
"""
try:
# Validate weight parameter
if not (0.0 <= ensemble_weight <= 1.0):
raise ValueError("ensemble_weight must be between 0.0 and 1.0")
long_weight = 1.0 - ensemble_weight
# Get prediction results
results_df = self.predict_sequence(sequence, window_size=window_size,
short_threshold=0.1, ensemble_weight=ensemble_weight)
if results_df.empty:
raise ValueError("Prediction results are empty, please check input sequence")
# Get sequence length
seq_length = len(sequence)
# Calculate display width
desired_visual_width = max(3, seq_length // 100) # FS site width ~1% of sequence length
prob_width = max(1, desired_visual_width // 3) # Prediction probability width is 1/3 of FS site width
# Create figure with three subplots, set height ratios
fig = plt.figure(figsize=figsize)
# Set title
if title:
fig.suptitle(title, y=0.95, fontsize=10)
else:
fig.suptitle(f'PRF Prediction Results (Weights {ensemble_weight:.1f}:{long_weight:.1f})', y=0.95, fontsize=10)
# Adjust subplot ratios, make top two heatmaps smaller
gs = fig.add_gridspec(3, 1, height_ratios=[0.1, 0.1, 1], hspace=0.2)
# FS site heatmap - using fixed width, no blur effect
ax0 = fig.add_subplot(gs[0])
fs_data = np.zeros((1, seq_length))
# Note: No actual FS site information in sliding window prediction, so keep empty or show predicted sites
# Show high-confidence predictions as potential FS sites
for _, row in results_df.iterrows():
pos = int(row['Position'])
if (row['Short_Probability'] >= short_threshold and
row['Long_Probability'] >= long_threshold and
row['Ensemble_Probability'] >= 0.8): # High confidence threshold
half_width = desired_visual_width // 2
start_pos = max(0, pos - half_width)
end_pos = min(seq_length, pos + half_width + 1)
fs_data[0, start_pos:end_pos] = 1 # Use fixed value, no gradient
ax0.imshow(fs_data, cmap='Reds', aspect='auto', interpolation='nearest')
ax0.set_xticks([])
ax0.set_yticks([])
ax0.set_title('FS site', pad=2, fontsize=8)
# Prediction probability heatmap - using fixed width to display probabilities
ax1 = fig.add_subplot(gs[1])
prob_data = np.zeros((1, seq_length))
# Apply dual threshold filtering
for _, row in results_df.iterrows():
pos = int(row['Position'])
if (row['Short_Probability'] >= short_threshold and
row['Long_Probability'] >= long_threshold):
# Set fixed width for each probability value
start = max(0, pos - prob_width//2)
end = min(seq_length, pos + prob_width//2 + 1)
prob_data[0, start:end] = row['Ensemble_Probability']
im = ax1.imshow(prob_data, cmap='Reds', aspect='auto', vmin=0, vmax=1, interpolation='nearest')
ax1.set_xticks([])
ax1.set_yticks([])
ax1.set_title('Prediction', pad=2, fontsize=8)
# Main plot (bar chart)
ax2 = fig.add_subplot(gs[2])
# Apply filtering thresholds
filtered_probs = results_df['Ensemble_Probability'].copy()
mask = ((results_df['Short_Probability'] < short_threshold) |
(results_df['Long_Probability'] < long_threshold))
filtered_probs[mask] = 0
# Draw bar chart - use black color and alpha=0.6 to match prediction_sample style
ax2.bar(results_df['Position'], filtered_probs,
alpha=0.6, color='black', width=1.0)
# Set x-axis ticks
step = max(seq_length // 10, 50)
ax2.set_xticks(np.arange(0, seq_length, step))
ax2.tick_params(axis='x', rotation=45)
# Set labels
ax2.set_xlabel('Position')
ax2.set_ylabel('Probability')
# Set y-axis range
ax2.set_ylim(0, 1)
# Add grid
ax2.grid(True, alpha=0.3)
# Ensure all subplots have consistent x-axis range
for ax in [ax0, ax1, ax2]:
ax.set_xlim(-1, seq_length)
# Adjust layout
plt.tight_layout()
# Save plot if save path is provided
if save_path:
plt.savefig(save_path, dpi=dpi, bbox_inches='tight')
# Also save PDF version
if save_path.endswith('.png'):
pdf_path = save_path.replace('.png', '.pdf')
plt.savefig(pdf_path, bbox_inches='tight')
print(f"Plot saved to: {save_path}")
return results_df, fig
except Exception as e:
raise Exception(f"Error plotting sequence prediction: {str(e)}")
def predict_regions(self, sequences, short_threshold=0.1, ensemble_weight=0.4):
'''
Predict region sequences (batch prediction of known 399bp sequences)
Args:
sequences: 399bp sequences or DataFrame/Series/list containing 399bp sequences
short_threshold: Short model probability threshold (default: 0.1)
ensemble_weight: Weight of short model in ensemble (default: 0.4)
Returns:
DataFrame: DataFrame containing prediction probabilities for all sequences
'''
try:
# Validate weight parameter
if not (0.0 <= ensemble_weight <= 1.0):
raise ValueError("ensemble_weight must be between 0.0 and 1.0")
# Unify input format
if isinstance(sequences, (pd.DataFrame, pd.Series)):
sequences = sequences.tolist()
elif isinstance(sequences, str):
sequences = [sequences]
results = []
for i, seq399 in enumerate(sequences):
try:
# Extract central 33bp from 399bp sequence (for short model use)
seq33 = self._extract_center_sequence(seq399, target_length=self.short_seq_length)
# Use unified prediction method
pred_result = self.predict_single_position(seq33, seq399, short_threshold, ensemble_weight)
pred_result.update({
'Short_Sequence': seq33,
'Long_Sequence': seq399
})
results.append(pred_result)
except Exception as e:
print(f"Error processing sequence {i+1}: {str(e)}")
long_weight = 1.0 - ensemble_weight
results.append({
'Short_Probability': 0.0,
'Long_Probability': 0.0,
'Ensemble_Probability': 0.0,
'Ensemble_Weights': f'Short:{ensemble_weight:.1f}, Long:{long_weight:.1f}',
'Short_Sequence': self._extract_center_sequence(seq399, target_length=self.short_seq_length) if len(seq399) >= self.short_seq_length else seq399,
'Long_Sequence': seq399
})
return pd.DataFrame(results)
except Exception as e:
raise Exception(f"Error in region prediction process: {str(e)}")
def _extract_center_sequence(self, sequence, target_length=33):
"""Extract subsequence of specified length from center position of sequence"""
# Ensure sequence is string
sequence = str(sequence).upper()
# If sequence length is less than target length, return original sequence
if len(sequence) <= target_length:
return sequence
# Calculate center position
center = len(sequence) // 2
half_target = target_length // 2
# Extract center sequence
start = center - half_target
end = start + target_length
# Boundary check
if start < 0:
start = 0
end = target_length
elif end > len(sequence):
end = len(sequence)
start = end - target_length
return sequence[start:end]
# 兼容性方法(向后兼容,但标记为废弃)
def predict_full(self, sequence, window_size=3, short_threshold=0.1, short_weight=0.4, plot=False):
"""
已废弃请使用 predict_sequence() 方法
向后兼容的方法内部调用新的 predict_sequence()
"""
import warnings
warnings.warn("predict_full() 已废弃,请使用 predict_sequence() 方法", DeprecationWarning, stacklevel=2)
# 调用新方法并添加兼容性字段
results_df = self.predict_sequence(sequence, window_size, short_threshold, short_weight)
# 添加兼容性字段
if 'Ensemble_Probability' in results_df.columns:
results_df['Voting_Probability'] = results_df['Ensemble_Probability']
results_df['Weighted_Probability'] = results_df['Ensemble_Probability']
if 'Ensemble_Weights' in results_df.columns:
results_df['Weight_Info'] = results_df['Ensemble_Weights']
if 'Short_Sequence' in results_df.columns:
results_df['33bp'] = results_df['Short_Sequence']
if 'Long_Sequence' in results_df.columns:
results_df['399bp'] = results_df['Long_Sequence']
if plot:
# 如果需要绘图,调用绘图方法
_, fig = self.plot_sequence_prediction(sequence, window_size, 0.65, 0.8, short_weight)
return results_df, fig
return results_df
def predict_region(self, seq, short_threshold=0.1, short_weight=0.4):
"""
已废弃请使用 predict_regions() 方法
向后兼容的方法内部调用新的 predict_regions()
"""
import warnings
warnings.warn("predict_region() 已废弃,请使用 predict_regions() 方法", DeprecationWarning, stacklevel=2)
# 调用新方法并添加兼容性字段
results_df = self.predict_regions(seq, short_threshold, short_weight)
# 添加兼容性字段
if 'Ensemble_Probability' in results_df.columns:
results_df['Voting_Probability'] = results_df['Ensemble_Probability']
results_df['Weighted_Probability'] = results_df['Ensemble_Probability']
if 'Ensemble_Weights' in results_df.columns:
results_df['Weight_Info'] = results_df['Ensemble_Weights']
if 'Short_Sequence' in results_df.columns:
results_df['33bp'] = results_df['Short_Sequence']
if 'Long_Sequence' in results_df.columns:
results_df['399bp'] = results_df['Long_Sequence']
return results_df

View File

@ -6,65 +6,58 @@
"source": [
"# FScanpy \n",
"\n",
"这个 Notebook 展示了如何使用 FScanpy 的真实测试数据进行完整的 PRF 位点预测分析,包括:\n",
"This notebook demonstrates how to use FScanpy with real test data for complete PRF site prediction analysis, including:\n",
"\n",
"## 🎯 完整工作流程\n",
"1. **加载测试数据** - 使用内置的真实测试数据\n",
"2. **FScanR 分析** - 从 BLASTX 结果识别潜在 PRF 位点\n",
"3. **序列提取** - 提取 PRF 位点周围的序列\n",
"4. **FScanpy 预测** - 使用机器学习模型预测概率\n",
"5. **结果可视化** - 使用内置绘图函数生成预测结果图表\n",
"6. **序列级预测演示** - 完整序列的滑动窗口分析\n",
"## 🎯 Complete Workflow\n",
"1. **Load Test Data** - Use built-in real test data\n",
"2. **FScanR Analysis** - Identify potential PRF sites from BLASTX results\n",
"3. **Sequence Extraction** - Extract sequences around PRF sites\n",
"4. **FScanpy Prediction** - Use machine learning models to predict probabilities\n",
"5. **Results Visualization** - Generate prediction result plots using built-in plotting functions\n",
"6. **Sequence-level Prediction Demo** - Sliding window analysis of complete sequences\n",
"\n",
"## 📊 数据说明\n",
"- **blastx_example.xlsx**: 真实BLASTX比对结果\n",
"- **mrna_example.fasta**: 真实mRNA序列数据\n",
"- **region_example.csv**: 单独对某个位点进行预测的样本"
"## 📊 Data Description\n",
"- **blastx_example.xlsx**: Real BLASTX alignment results\n",
"- **mrna_example.fasta**: Real mRNA sequence data\n",
"- **region_example.csv**: Sample for individual site prediction"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 📦 环境准备和数据加载"
"## 📦 Environment Setup and Data Loading"
]
},
{
"cell_type": "code",
"execution_count": 22,
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"✅ 环境准备完成!\n",
"📋 可用的测试数据:\n"
"ename": "ImportError",
"evalue": "cannot import name 'PRFPredictor' from 'FScanpy' (unknown location)",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mImportError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[3], line 6\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mmatplotlib\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mpyplot\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mplt\u001b[39;00m\n\u001b[1;32m 5\u001b[0m \u001b[38;5;66;03m# Import FScanpy related modules\u001b[39;00m\n\u001b[0;32m----> 6\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mFScanpy\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m PRFPredictor, predict_prf, plot_prf_prediction\n\u001b[1;32m 7\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mFScanpy\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mdata\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m get_test_data_path, list_test_data\n\u001b[1;32m 8\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mFScanpy\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mutils\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m fscanr, extract_prf_regions\n",
"\u001b[0;31mImportError\u001b[0m: cannot import name 'PRFPredictor' from 'FScanpy' (unknown location)"
]
},
{
"data": {
"text/plain": [
"['blastx_example.xlsx', 'mrna_example.fasta', 'region_example.csv']"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# 导入必要的库\n",
"# Import necessary libraries\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# 导入FScanpy相关模块\n",
"# Import FScanpy related modules\n",
"from FScanpy import PRFPredictor, predict_prf, plot_prf_prediction\n",
"from FScanpy.data import get_test_data_path, list_test_data\n",
"from FScanpy.utils import fscanr, extract_prf_regions\n",
"\n",
"print(\"✅ 环境准备完成!\")\n",
"print(\"📋 可用的测试数据:\")\n",
"print(\"✅ Environment setup complete!\")\n",
"print(\"📋 Available test data:\")\n",
"list_test_data()"
]
},
@ -72,9 +65,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. 加载和探索测试数据\n",
"## 1. Load and Explore Test Data\n",
"\n",
"首先加载 FScanpy 提供的真实测试数据,了解数据结构。"
"First, load the real test data provided by FScanpy to understand the data structure."
]
},
{
@ -107,25 +100,25 @@
}
],
"source": [
"# 获取测试数据路径\n",
"# Get test data paths\n",
"blastx_file = get_test_data_path('blastx_example.xlsx')\n",
"mrna_file = get_test_data_path('mrna_example.fasta')\n",
"region_file = get_test_data_path('region_example.csv')\n",
"\n",
"print(f\"📁 数据文件路径:\")\n",
"print(f\" BLASTX数据: {blastx_file}\")\n",
"print(f\" mRNA序列: {mrna_file}\")\n",
"print(f\" 验证区域: {region_file}\")\n",
"print(f\"📁 Data file paths:\")\n",
"print(f\" BLASTX data: {blastx_file}\")\n",
"print(f\" mRNA sequences: {mrna_file}\")\n",
"print(f\" Validation regions: {region_file}\")\n",
"\n",
"# 加载BLASTX数据\n",
"# Load BLASTX data\n",
"blastx_data = pd.read_excel(blastx_file)\n",
"print(f\"\\n🧬 BLASTX数据概览:\")\n",
"print(f\" 数据形状: {blastx_data.shape}\")\n",
"print(f\" 列名: {list(blastx_data.columns)}\")\n",
"print(f\" 唯一序列数: {blastx_data['DNA_seqid'].nunique()}\")\n",
"print(f\"\\n🧬 BLASTX data overview:\")\n",
"print(f\" Data shape: {blastx_data.shape}\")\n",
"print(f\" Column names: {list(blastx_data.columns)}\")\n",
"print(f\" Unique sequences: {blastx_data['DNA_seqid'].nunique()}\")\n",
"\n",
"# 显示前几行\n",
"print(\"\\n📊 BLASTX数据示例:\")\n",
"# Display first few rows\n",
"print(\"\\n📊 BLASTX data examples:\")\n",
"display_cols = ['DNA_seqid', 'Pep_seqid', 'pident', 'length', 'evalue', 'qframe']\n",
"print(blastx_data[display_cols].head())"
]
@ -163,21 +156,21 @@
}
],
"source": [
"# 加载验证区域数据\n",
"# Load validation region data\n",
"region_data = pd.read_csv(region_file)\n",
"print(f\"🎯 验证区域数据概览:\")\n",
"print(f\" 数据形状: {region_data.shape}\")\n",
"print(f\" 列名: {list(region_data.columns)}\")\n",
"print(f\" 数据来源: {region_data['source'].value_counts().to_dict()}\")\n",
"print(f\"🎯 Validation region data overview:\")\n",
"print(f\" Data shape: {region_data.shape}\")\n",
"print(f\" Column names: {list(region_data.columns)}\")\n",
"print(f\" Data sources: {region_data['source'].value_counts().to_dict()}\")\n",
"\n",
"print(\"\\n📋 验证区域数据示例:\")\n",
"print(\"\\n📋 Validation region data examples:\")\n",
"display_cols = ['fs_position', 'DNA_seqid', 'label', 'source', 'FS_type']\n",
"print(region_data[display_cols].head())\n",
"\n",
"# 统计分析\n",
"print(f\"\\n📈 标签分布:\")\n",
"# Statistical analysis\n",
"print(f\"\\n📈 Label distribution:\")\n",
"print(region_data['label'].value_counts())\n",
"print(f\"\\n🔬 FS类型分布:\")\n",
"print(f\"\\n🔬 FS type distribution:\")\n",
"print(region_data['FS_type'].value_counts())"
]
},
@ -185,9 +178,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. FScanR 分析 - 从 BLASTX 识别潜在 PRF 位点\n",
"## 2. FScanR Analysis - Identify Potential PRF Sites from BLASTX\n",
"\n",
"使用 FScanR 算法分析 BLASTX 结果,识别潜在的程序性核糖体移码位点。"
"Use the FScanR algorithm to analyze BLASTX results and identify potential programmed ribosomal frameshift sites."
]
},
{
@ -229,9 +222,9 @@
}
],
"source": [
"# 运行FScanR分析\n",
"print(\"🔍 运行FScanR分析...\")\n",
"print(\"参数设置: mismatch_cutoff=10, evalue_cutoff=1e-5, frameDist_cutoff=10\")\n",
"# Run FScanR analysis\n",
"print(\"🔍 Running FScanR analysis...\")\n",
"print(\"Parameter settings: mismatch_cutoff=10, evalue_cutoff=1e-5, frameDist_cutoff=100\")\n",
"\n",
"fscanr_results = fscanr(\n",
" blastx_data,\n",
@ -240,29 +233,29 @@
" frameDist_cutoff=100\n",
")\n",
"\n",
"print(f\"\\n✅ FScanR分析完成!\")\n",
"print(f\"检测到的潜在PRF位点数量: {len(fscanr_results)}\")\n",
"print(f\"\\n✅ FScanR analysis complete!\")\n",
"print(f\"Number of potential PRF sites detected: {len(fscanr_results)}\")\n",
"\n",
"if len(fscanr_results) > 0:\n",
" print(f\"\\n📊 FScanR结果概览:\")\n",
" print(f\" 列名: {list(fscanr_results.columns)}\")\n",
" print(f\" 涉及的序列数: {fscanr_results['DNA_seqid'].nunique()}\")\n",
" print(f\" 链方向分布: {fscanr_results['Strand'].value_counts().to_dict()}\")\n",
" print(f\" FS类型分布: {fscanr_results['FS_type'].value_counts().to_dict()}\")\n",
" print(f\"\\n📊 FScanR results overview:\")\n",
" print(f\" Column names: {list(fscanr_results.columns)}\")\n",
" print(f\" Number of sequences involved: {fscanr_results['DNA_seqid'].nunique()}\")\n",
" print(f\" Strand orientation distribution: {fscanr_results['Strand'].value_counts().to_dict()}\")\n",
" print(f\" FS type distribution: {fscanr_results['FS_type'].value_counts().to_dict()}\")\n",
" \n",
" print(\"\\n🎯 FScanR结果示例:\")\n",
" print(\"\\n🎯 FScanR results examples:\")\n",
" print(fscanr_results.head())\n",
"else:\n",
" print(\"⚠️ 未检测到PRF位点可能需要调整参数\")"
" print(\"⚠️ No PRF sites detected, may need to adjust parameters\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. 序列提取 - 获取 PRF 位点周围序列\n",
"## 3. Sequence Extraction - Extract Sequences Around PRF Sites\n",
"\n",
"从 mRNA 序列中提取 FScanR 识别的 PRF 位点周围的序列片段。"
"Extract sequence fragments around PRF sites identified by FScanR from mRNA sequences."
]
},
{
@ -306,36 +299,36 @@
}
],
"source": [
"# 提取PRF位点周围的序列\n",
"# Extract sequences around PRF sites\n",
"if len(fscanr_results) > 0:\n",
" print(\"📝 从mRNA序列中提取PRF位点周围序列...\")\n",
" print(\"📝 Extracting sequences around PRF sites from mRNA sequences...\")\n",
" \n",
" prf_sequences = extract_prf_regions(\n",
" mrna_file=mrna_file,\n",
" prf_data=fscanr_results\n",
" )\n",
" \n",
" print(f\"\\n✅ 序列提取完成!\")\n",
" print(f\"成功提取的序列数量: {len(prf_sequences)}\")\n",
" print(f\"\\n✅ Sequence extraction complete!\")\n",
" print(f\"Number of successfully extracted sequences: {len(prf_sequences)}\")\n",
" \n",
" if len(prf_sequences) > 0:\n",
" print(f\"\\n📏 序列长度验证:\")\n",
" print(f\"\\n📏 Sequence length validation:\")\n",
" seq_lengths = prf_sequences['399bp'].str.len()\n",
" print(f\" 399bp序列长度分布: {seq_lengths.value_counts().to_dict()}\")\n",
" print(f\" 平均长度: {seq_lengths.mean():.1f}\")\n",
" print(f\" 399bp sequence length distribution: {seq_lengths.value_counts().to_dict()}\")\n",
" print(f\" Average length: {seq_lengths.mean():.1f}\")\n",
" \n",
" print(\"\\n🧬 提取序列示例:\")\n",
" print(\"\\n🧬 Extracted sequence examples:\")\n",
" for i, row in prf_sequences.head(3).iterrows():\n",
" print(f\"序列 {i+1}: {row['DNA_seqid']}\")\n",
" print(f\" FS位置: {row['FS_start']}-{row['FS_end']}\")\n",
" print(f\" 链方向: {row['Strand']}\")\n",
" print(f\" FS类型: {row['FS_type']}\")\n",
" print(f\" 序列片段: {row['399bp'][:50]}...{row['399bp'][-20:]}\")\n",
" print(f\"Sequence {i+1}: {row['DNA_seqid']}\")\n",
" print(f\" FS position: {row['FS_start']}-{row['FS_end']}\")\n",
" print(f\" Strand orientation: {row['Strand']}\")\n",
" print(f\" FS type: {row['FS_type']}\")\n",
" print(f\" Sequence fragment: {row['399bp'][:50]}...{row['399bp'][-20:]}\")\n",
" print()\n",
" else:\n",
" print(\"❌ 序列提取失败\")\n",
" print(\"❌ Sequence extraction failed\")\n",
"else:\n",
" print(\"⚠️ 跳过序列提取 - 无FScanR结果\")\n",
" print(\"⚠️ Skipping sequence extraction - no FScanR results\")\n",
" prf_sequences = pd.DataFrame()"
]
},
@ -343,9 +336,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4. FScanpy 预测 - 机器学习模型分析\n",
"## 4. FScanpy Prediction - Machine Learning Model Analysis\n",
"\n",
"使用 FScanpy 的机器学习模型对提取的序列进行 PRF 概率预测。"
"Use FScanpy's machine learning models to predict PRF probabilities for the extracted sequences."
]
},
{
@ -379,26 +372,26 @@
}
],
"source": [
"# 初始化预测器\n",
"# Initialize predictor\n",
"predictor = PRFPredictor()\n",
"print(\"🤖 FScanpy预测器初始化完成\")\n",
"print(\"🤖 FScanpy predictor initialization complete\")\n",
"\n",
"# 对FScanR识别的序列进行预测\n",
"# Predict FScanR identified sequences\n",
"if len(prf_sequences) > 0:\n",
" print(f\"\\n🎯 对 {len(prf_sequences)} 个FScanR识别的序列进行预测...\")\n",
" print(f\"\\n🎯 Predicting {len(prf_sequences)} sequences identified by FScanR...\")\n",
" \n",
" fscanr_predictions = predictor.predict_regions(\n",
" sequences=prf_sequences['399bp'],\n",
" ensemble_weight=0.4 # 平衡配置\n",
" ensemble_weight=0.4 # Balanced configuration\n",
" )\n",
" \n",
" # 合并结果\n",
" # Merge results\n",
" fscanr_predictions = pd.concat([\n",
" prf_sequences.reset_index(drop=True),\n",
" fscanr_predictions.reset_index(drop=True)\n",
" ], axis=1)\n",
" \n",
" print(\"\\n📊 FScanR+FScanpy预测结果:\")\n",
" print(\"\\n📊 FScanR+FScanpy prediction results:\")\n",
" result_cols = ['DNA_seqid', 'FS_start', 'FS_type', 'Short_Probability', 'Long_Probability', 'Ensemble_Probability']\n",
" print(fscanr_predictions[result_cols].head())"
]
@ -429,15 +422,15 @@
}
],
"source": [
"# 对验证区域数据进行预测\n",
"print(f\"\\n🧪 对 {len(region_data)} 个验证区域进行预测...\")\n",
"# Predict validation region data\n",
"print(f\"\\n🧪 Predicting {len(region_data)} validation regions...\")\n",
"\n",
"validation_predictions = predict_prf(\n",
" data=region_data.rename(columns={'399bp': 'Long_Sequence'}),\n",
" ensemble_weight=0.4\n",
")\n",
"\n",
"print(\"\\n📊 验证区域预测结果:\")\n",
"print(\"\\n📊 Validation region prediction results:\")\n",
"result_cols = ['DNA_seqid', 'label', 'source', 'Short_Probability', 'Long_Probability', 'Ensemble_Probability']\n",
"print(validation_predictions[result_cols].head())"
]
@ -446,9 +439,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5. 序列级预测和可视化\n",
"## 5. Sequence-level Prediction and Visualization\n",
"\n",
"选择一个具体的mRNA序列使用内置的plot_prf_prediction函数进行完整的滑动窗口预测和可视化。"
"Select a specific mRNA sequence and use the built-in plot_prf_prediction function for complete sliding window prediction and visualization."
]
},
{
@ -557,19 +550,19 @@
}
],
"source": [
"# 选择一个序列进行演示\n",
"# Select a sequence for demonstration\n",
"from Bio import SeqIO\n",
"\n",
"# 读取第一个mRNA序列作为演示\n",
"# Read the first mRNA sequence for demonstration\n",
"mrna_sequences = list(SeqIO.parse(mrna_file, \"fasta\"))\n",
"demo_seq = mrna_sequences[0] # 选择第一个序列\n",
"demo_seq = mrna_sequences[0] # Select the first sequence\n",
"\n",
"print(f\"🧬 选择演示序列: {demo_seq.id}\")\n",
"print(f\"序列长度: {len(demo_seq.seq)} bp\")\n",
"print(f\"序列前100bp: {str(demo_seq.seq)[:100]}...\")\n",
"print(f\"🧬 Selected demonstration sequence: {demo_seq.id}\")\n",
"print(f\"Sequence length: {len(demo_seq.seq)} bp\")\n",
"print(f\"First 100bp of sequence: {str(demo_seq.seq)[:100]}...\")\n",
"\n",
"# 使用内置的plot_prf_prediction函数进行预测和可视化\n",
"print(f\"\\n🎯 使用plot_prf_prediction进行序列预测和可视化...\")\n",
"# Use built-in plot_prf_prediction function for prediction and visualization\n",
"print(f\"\\n🎯 Using plot_prf_prediction for sequence prediction and visualization...\")\n",
"\n",
"sequence_results, fig = plot_prf_prediction(\n",
" sequence=str(demo_seq.seq),\n",
@ -577,18 +570,18 @@
" short_threshold=0.2,\n",
" long_threshold=0.2,\n",
" ensemble_weight=0.6,\n",
" title=f\"序列 {demo_seq.id} 的PRF预测结果条形图+热图)\",\n",
" title=f\"PRF Prediction Results for Sequence {demo_seq.id} (Bar Chart + Heatmap)\",\n",
" figsize=(16, 8),\n",
" dpi=150\n",
")\n",
"\n",
"plt.show()\n",
"\n",
"print(f\"\\n📊 序列预测结果统计:\")\n",
"print(f\" 预测位点总数: {len(sequence_results)}\")\n",
"print(f\" 高概率位点 (>0.8): {(sequence_results['Ensemble_Probability'] > 0.8).sum()}\")\n",
"print(f\" 中概率位点 (0.4-0.8): {((sequence_results['Ensemble_Probability'] >= 0.4) & (sequence_results['Ensemble_Probability'] <= 0.8)).sum()}\")\n",
"print(f\" 最高预测概率: {sequence_results['Ensemble_Probability'].max():.3f}\")"
"print(f\"\\n📊 Sequence prediction result statistics:\")\n",
"print(f\" Total predicted sites: {len(sequence_results)}\")\n",
"print(f\" High probability sites (>0.8): {(sequence_results['Ensemble_Probability'] > 0.8).sum()}\")\n",
"print(f\" Medium probability sites (0.4-0.8): {((sequence_results['Ensemble_Probability'] >= 0.4) & (sequence_results['Ensemble_Probability'] <= 0.8)).sum()}\")\n",
"print(f\" Highest prediction probability: {sequence_results['Ensemble_Probability'].max():.3f}\")"
]
},
{
@ -634,53 +627,53 @@
}
],
"source": [
"# 打印Top预测位点的概率\n",
"# Print top predicted site probabilities\n",
"if sequence_results['Ensemble_Probability'].max() > 0.3:\n",
" top_predictions = sequence_results.nlargest(5, 'Ensemble_Probability')\n",
" print(f\"\\n🔝 Top 5 预测位点:\")\n",
" print(f\"\\n🔝 Top 5 predicted sites:\")\n",
" for i, (_, row) in enumerate(top_predictions.iterrows(), 1):\n",
" print(f\" {i}. 位置 {row['Position']}: \")\n",
" print(f\" - Short概率: {row['Short_Probability']:.3f}\")\n",
" print(f\" - Long概率: {row['Long_Probability']:.3f}\")\n",
" print(f\" - 集成概率: {row['Ensemble_Probability']:.3f}\")\n",
" print(f\" - 密码子: {row['Codon']}\")\n",
" print(f\" {i}. Position {row['Position']}: \")\n",
" print(f\" - Short probability: {row['Short_Probability']:.3f}\")\n",
" print(f\" - Long probability: {row['Long_Probability']:.3f}\")\n",
" print(f\" - Ensemble probability: {row['Ensemble_Probability']:.3f}\")\n",
" print(f\" - Codon: {row['Codon']}\")\n",
"else:\n",
" print(\"\\n💡 该序列没有检测到高概率的PRF位点\")\n",
" print(\"\\n💡 No high-probability PRF sites detected in this sequence\")\n",
"\n",
"print(\"\\n📊 可视化分析完成!\")\n",
"print(\"图表包含热图和条形图展示了整个序列的PRF预测概率分布。\")"
"print(\"\\n📊 Visualization analysis complete!\")\n",
"print(\"The chart contains heatmaps and bar charts showing the PRF prediction probability distribution across the entire sequence.\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 📝 分析总结\n",
"## 📝 Analysis Summary\n",
"\n",
"### 🎯 主要发现\n",
"1. **数据质量**: 测试数据集包含真实的BLASTX比对结果和验证区域\n",
"2. **FScanR效果**: 从BLASTX结果中识别出潜在PRF位点\n",
"3. **模型性能**: Short和Long模型在不同场景下各有优势\n",
"4. **预测结果**: 集成模型提供了更稳定的预测性能\n",
"5. **可视化**: 内置绘图函数生成清晰的热图和条形图\n",
"### 🎯 Key Findings\n",
"1. **Data Quality**: Test dataset contains real BLASTX alignment results and validation regions\n",
"2. **FScanR Performance**: Successfully identified potential PRF sites from BLASTX results\n",
"3. **Model Performance**: Short and Long models each have advantages in different scenarios\n",
"4. **Prediction Results**: Ensemble model provides more stable prediction performance\n",
"5. **Visualization**: Built-in plotting functions generate clear heatmaps and bar charts\n",
"\n",
"### 🔧 最佳实践\n",
"- **数据预处理**: 确保BLASTX结果格式正确\n",
"- **参数设置**: 使用默认的集成权重(0.4:0.6)获得平衡性能\n",
"- **结果解读**: 在使用FScanpy对整条序列进行预测时,不应该使用0.5作为阈值,而应该比较不同位置的概率高低\n",
"- **可视化**: 使用plot_prf_prediction函数生成标准化图表\n",
"### 🔧 Best Practices\n",
"- **Data Preprocessing**: Ensure BLASTX results are in correct format\n",
"- **Parameter Settings**: Use default ensemble weights (0.4:0.6) for balanced performance\n",
"- **Result Interpretation**: When using FScanpy for whole sequence prediction, don't use 0.5 as threshold, but compare relative probabilities across positions\n",
"- **Visualization**: Use plot_prf_prediction function to generate standardized plots\n",
"\n",
"### 📚 使用建议\n",
"1. **阈值选择**: 根据应用场景调整概率阈值\n",
"2. **结果验证**: 结合生物学知识验证预测结果\n",
"3. **性能优化**: 对于大规模数据使用合理的滑动窗口大小\n",
"4. **可视化参数**: 调整figsize和dpi获得最佳显示效果"
"### 📚 Usage Recommendations\n",
"1. **Threshold Selection**: Adjust probability thresholds based on application scenarios\n",
"2. **Result Validation**: Validate prediction results with biological knowledge\n",
"3. **Performance Optimization**: Use reasonable sliding window sizes for large-scale data\n",
"4. **Visualization Parameters**: Adjust figsize and dpi for optimal display"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "tf200",
"language": "python",
"name": "python3"
},
@ -694,7 +687,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
"version": "3.9.0"
}
},
"nbformat": 4,