ipynb文件创建,绘图函数完善

This commit is contained in:
Chenlab 2025-05-29 17:58:48 +08:00
parent 7587bd632b
commit 3833704c47
15 changed files with 1899 additions and 427 deletions

View File

@ -1,14 +0,0 @@
Metadata-Version: 2.4
Name: FScanpy
Version: 1.0.0
Summary: PRF prediction tool
Author: FScanpy Developer
Author-email: FScanpy Developer <example@example.com>
Requires-Python: >=3.7
Requires-Dist: numpy
Requires-Dist: pandas
Requires-Dist: tensorflow
Requires-Dist: scikit-learn
Requires-Dist: wrapt>=1.10.11
Dynamic: author
Dynamic: requires-python

View File

@ -1,14 +0,0 @@
README.md
pyproject.toml
setup.py
FScanpy/__init__.py
FScanpy/predictor.py
FScanpy/utils.py
FScanpy.egg-info/PKG-INFO
FScanpy.egg-info/SOURCES.txt
FScanpy.egg-info/dependency_links.txt
FScanpy.egg-info/requires.txt
FScanpy.egg-info/top_level.txt
FScanpy/features/__init__.py
FScanpy/features/cnn_input.py
FScanpy/features/sequence.py

View File

@ -1 +0,0 @@

View File

@ -1,5 +0,0 @@
numpy
pandas
tensorflow
scikit-learn
wrapt>=1.10.11

View File

@ -1 +0,0 @@
FScanpy

View File

@ -1,4 +1,5 @@
from .predictor import PRFPredictor
from . import data
import pandas as pd
import numpy as np
from typing import Union, List, Dict
@ -7,13 +8,14 @@ __version__ = '0.3.0'
__author__ = ''
__email__ = ''
__all__ = ['PRFPredictor', 'predict_prf', '__version__', '__author__', '__email__']
__all__ = ['PRFPredictor', 'predict_prf', 'plot_prf_prediction', 'data', '__version__', '__author__', '__email__']
def predict_prf(
sequence: Union[str, List[str], None] = None,
data: Union[pd.DataFrame, None] = None,
window_size: int = 3,
gb_threshold: float = 0.1,
short_threshold: float = 0.1,
ensemble_weight: float = 0.4,
model_dir: str = None
) -> pd.DataFrame:
"""
@ -21,13 +23,18 @@ def predict_prf(
Args:
sequence: 单个或多个DNA序列用于滑动窗口预测
data: DataFrame数据必须包含'399bp'用于区域预测
data: DataFrame数据必须包含'Long_Sequence''399bp'用于区域预测
window_size: 滑动窗口大小默认为3
gb_threshold: GB模型概率阈值默认为0.1
short_threshold: Short模型(HistGB)概率阈值默认为0.1
ensemble_weight: Short模型在集成中的权重默认为0.4Long权重为0.6
model_dir: 模型文件目录路径可选
Returns:
pandas.DataFrame: 预测结果
pandas.DataFrame: 预测结果包含以下主要字段
- Short_Probability: Short模型预测概率
- Long_Probability: Long模型预测概率
- Ensemble_Probability: 集成预测概率主要结果
- Ensemble_Weights: 权重配置信息
Examples:
# 1. 单条序列滑动窗口预测
@ -39,10 +46,13 @@ def predict_prf(
>>> sequences = ["ATGCGTACGT...", "GCTATAGCAT..."]
>>> results = predict_prf(sequence=sequences)
# 3. DataFrame区域预测
# 3. 自定义集成权重比例
>>> results = predict_prf(sequence=sequence, ensemble_weight=0.3) # 3:7 比例
# 4. DataFrame区域预测
>>> import pandas as pd
>>> data = pd.DataFrame({
... '399bp': ['ATGCGT...', 'GCTATAG...']
... 'Long_Sequence': ['ATGCGT...', 'GCTATAG...'] # 或使用 '399bp'
... })
>>> results = predict_prf(data=data)
"""
@ -53,20 +63,22 @@ def predict_prf(
raise ValueError("必须提供sequence或data参数之一")
if sequence is not None and data is not None:
raise ValueError("sequence和data参数不能同时提供")
if not (0.0 <= ensemble_weight <= 1.0):
raise ValueError("ensemble_weight 必须在 0.0 到 1.0 之间")
# 滑动窗口预测模式
if sequence is not None:
if isinstance(sequence, str):
# 单条序列预测
return predictor.predict_full(
sequence, window_size, gb_threshold)
return predictor.predict_sequence(
sequence, window_size, short_threshold, ensemble_weight)
elif isinstance(sequence, (list, tuple)):
# 多条序列预测
results = []
for i, seq in enumerate(sequence, 1):
try:
result = predictor.predict_full(
seq, window_size, gb_threshold)
result = predictor.predict_sequence(
seq, window_size, short_threshold, ensemble_weight)
result['Sequence_ID'] = f'seq_{i}'
results.append(result)
except Exception as e:
@ -78,17 +90,23 @@ def predict_prf(
if not isinstance(data, pd.DataFrame):
raise ValueError("data参数必须是pandas DataFrame类型")
if '399bp' not in data.columns:
raise ValueError("DataFrame必须包含'399bp'")
# 检查列名(支持新旧两种命名)
seq_column = None
if 'Long_Sequence' in data.columns:
seq_column = 'Long_Sequence'
elif '399bp' in data.columns:
seq_column = '399bp'
else:
raise ValueError("DataFrame必须包含'Long_Sequence''399bp'")
# 调用区域预测函数
try:
results = predictor.predict_region(
data['399bp'], gb_threshold)
results = predictor.predict_regions(
data[seq_column], short_threshold, ensemble_weight)
# 添加原始数据的其他列
for col in data.columns:
if col not in ['399bp', '33bp']:
if col not in ['Long_Sequence', '399bp', 'Short_Sequence', '33bp']:
results[col] = data[col].values
return results
@ -96,14 +114,92 @@ def predict_prf(
except Exception as e:
print(f"警告:区域预测失败 - {str(e)}")
# 创建空结果
long_weight = 1.0 - ensemble_weight
results = pd.DataFrame({
'GB_Probability': [0.0] * len(data),
'CNN_Probability': [0.0] * len(data),
'Voting_Probability': [0.0] * len(data)
'Short_Probability': [0.0] * len(data),
'Long_Probability': [0.0] * len(data),
'Ensemble_Probability': [0.0] * len(data),
'Ensemble_Weights': [f'Short:{ensemble_weight:.1f}, Long:{long_weight:.1f}'] * len(data)
})
# 添加原始数据列
for col in data.columns:
results[col] = data[col].values
return results
return results
def plot_prf_prediction(
    sequence: str,
    window_size: int = 3,
    short_threshold: float = 0.65,
    long_threshold: float = 0.8,
    ensemble_weight: float = 0.4,
    title: Union[str, None] = None,
    save_path: Union[str, None] = None,
    figsize: tuple = (12, 6),
    dpi: int = 300,
    model_dir: Union[str, None] = None
) -> tuple:
    """
    Plot the frameshift-probability profile predicted for one sequence.

    This is a thin convenience wrapper: it validates the cheap arguments,
    builds a ``PRFPredictor`` and delegates to
    ``PRFPredictor.plot_sequence_prediction``.

    Args:
        sequence: Input DNA sequence.
        window_size: Sliding-window step, default 3.
        short_threshold: Filter threshold for the Short model, default 0.65.
        long_threshold: Filter threshold for the Long model, default 0.8.
        ensemble_weight: Weight of the Short model in the ensemble, default
            0.4; the Long model receives ``1 - ensemble_weight``.
        title: Optional figure title.
        save_path: Optional output path; forwarded to the predictor.
        figsize: Figure size, default ``(12, 6)``.
        dpi: Resolution used when saving, default 300.
        model_dir: Optional directory holding the model files.

    Returns:
        tuple: ``(results, figure)`` exactly as returned by
        ``PRFPredictor.plot_sequence_prediction``.

    Raises:
        ValueError: If ``ensemble_weight`` is outside ``[0.0, 1.0]`` or
            ``sequence`` is ``None``/empty.

    Examples:
        # 1. Simple plot
        >>> import matplotlib.pyplot as plt
        >>> from FScanpy import plot_prf_prediction
        >>> sequence = "ATGCGTACGT..."
        >>> results, fig = plot_prf_prediction(sequence)
        >>> plt.show()

        # 2. Custom thresholds and ensemble weight
        >>> results, fig = plot_prf_prediction(
        ...     sequence,
        ...     short_threshold=0.7,
        ...     long_threshold=0.85,
        ...     ensemble_weight=0.3,  # 3:7 weighting
        ...     title="custom weights",
        ...     save_path="prediction_result.png"
        ... )

        # 3. Equal weights
        >>> results, fig = plot_prf_prediction(sequence, ensemble_weight=0.5)

        # 4. Long model dominates
        >>> results, fig = plot_prf_prediction(sequence, ensemble_weight=0.2)
    """
    # Validate before constructing the predictor so that bad arguments fail
    # fast instead of after the model files have been loaded.
    if not (0.0 <= ensemble_weight <= 1.0):
        raise ValueError("ensemble_weight 必须在 0.0 到 1.0 之间")
    if sequence is None or len(sequence) == 0:
        raise ValueError("sequence 不能为空")

    predictor = PRFPredictor(model_dir=model_dir)
    return predictor.plot_sequence_prediction(
        sequence=sequence,
        window_size=window_size,
        short_threshold=short_threshold,
        long_threshold=long_threshold,
        ensemble_weight=ensemble_weight,
        title=title,
        save_path=save_path,
        figsize=figsize,
        dpi=dpi
    )

View File

@ -1,9 +1,120 @@
"""
FScanpy数据模块
提供测试数据访问和处理功能
"""
import os
import pkg_resources
from pathlib import Path
from typing import List
def get_test_data_path(filename: str) -> str:
    """
    Return the path of a bundled test-data file.

    The file is looked up in the ``test_data`` directory that sits next to
    this module.

    Args:
        filename: Name of the test-data file.

    Returns:
        str: Path of the file inside the ``test_data`` directory.

    Raises:
        FileNotFoundError: If the file does not exist; the message lists the
            files reported by ``list_test_data()``.

    Examples:
        >>> from FScanpy.data import get_test_data_path
        >>> blastx_file = get_test_data_path('blastx_example.xlsx')
        >>> mrna_file = get_test_data_path('mrna_example.fasta')
        >>> region_file = get_test_data_path('region_example.csv')
    """
    file_path = Path(__file__).parent / "test_data" / filename

    if not file_path.exists():
        available_files = list_test_data()
        raise FileNotFoundError(
            f"测试数据文件不存在: {filename}\n"
            f"可用的测试数据文件: {available_files}"
        )

    return str(file_path)
def list_test_data() -> List[str]:
    """
    List the available test-data files.

    Only regular files directly inside the ``test_data`` directory next to
    this module are reported; names starting with ``.`` are skipped.

    Returns:
        List[str]: Sorted file names. An empty list is returned when the
        directory does not exist or cannot be read.

    Examples:
        >>> from FScanpy.data import list_test_data
        >>> files = list_test_data()
        >>> print(files)
        ['blastx_example.xlsx', 'mrna_example.fasta', 'region_example.csv']
    """
    test_data_dir = Path(__file__).parent / "test_data"
    try:
        if not test_data_dir.exists():
            return []
        return sorted(
            entry.name
            for entry in test_data_dir.iterdir()
            if entry.is_file() and not entry.name.startswith('.')
        )
    except OSError:
        # Best effort: an unreadable directory is reported as "no files".
        return []
def print_test_data_info():
    """
    Print a human-readable summary of the bundled test data.

    For every file reported by ``list_test_data()`` the description, name,
    size in MB and path are printed, followed by a short usage example.
    Nothing is returned; any error raised while collecting the information
    is printed instead of propagated.
    """
    print("📋 FScanpy 测试数据信息:")
    print("=" * 50)

    # Known example files and the description printed for each; any other
    # file falls back to a generic label.
    file_descriptions = {
        'blastx_example.xlsx': '🧬 BLASTX比对结果示例 (1000条记录)',
        'mrna_example.fasta': '🧬 mRNA序列示例数据',
        'region_example.csv': '🎯 PRF区域验证数据 (含标签)'
    }

    try:
        test_data_dir = Path(__file__).parent / "test_data"

        if not test_data_dir.exists():
            print("❌ 测试数据目录不存在")
            return

        files = list_test_data()
        if not files:
            print("❌ 没有找到测试数据文件")
            return

        print(f"📁 数据目录: {test_data_dir}")
        print(f"📊 文件数量: {len(files)}")
        print()

        for filename in files:
            file_path = test_data_dir / filename
            size_mb = file_path.stat().st_size / (1024 * 1024)
            description = file_descriptions.get(filename, '📄 数据文件')

            print(f" {description}")
            print(f" 文件名: {filename}")
            print(f" 大小: {size_mb:.2f} MB")
            print(f" 路径: {file_path}")
            print()

        print("🚀 使用示例:")
        print(" from FScanpy.data import get_test_data_path")
        print(" blastx_file = get_test_data_path('blastx_example.xlsx')")
        print(" mrna_file = get_test_data_path('mrna_example.fasta')")
        print(" region_file = get_test_data_path('region_example.csv')")

    except Exception as e:
        # Deliberately broad: this is a diagnostic printer and must not raise.
        print(f"❌ 获取数据信息时出错: {e}")
# 导出主要函数
__all__ = ['get_test_data_path', 'list_test_data', 'print_test_data_info']

View File

@ -1,8 +1,4 @@
FS_period,399bp,fs_position,DNA_seqid,label,source,FS_type,dataset
gtgtgaacacaatagtgagtgacatactaaacg,ggatatgtaacatggacaagtcattgtgtaggtatccaagaccaatagcctttactttcaaagggaaagaagaattcaggaccttgtttggaggactcatctcgatgtcgattcaggtggtcattgtgctctatgcttatattatgctaaagataatgatagaacgtaatgacacatcaaaaagtgtgaacacaatagtgagtgacatactaaacgacaaatctccagtatctctcaatacaacagatttctcgtttgcatttgatgcttttattcttggcgatgataatttcgatttcaacaataaccaatacttcggaattgagctacttcaatggattaagcagccagatactggagaactatcatccactaatattccatatgaaagatgtggaa,16.0,MSTRG.18491.1,0,EUPLOTES,negative,EUPLOTES
gtctcagaagagtctgaggaatatctccaagga,caaattaataacaaatatgaattccatcaacaacttttatggagacgagaacttatcagatgaacttctgagtgaagatgtcgtgtcttgagaagtaagaggatcagaaaagatcttgcataacatggggagaaagtctctcagtaataagaagcctttaagcggagtggagttggactgagagtctcagaagagtctgaggaatatctccaaggataaaatttgttcgcaaggaagatctatctttaggcagaagaagtcaaaatcttgtgatcaagtagaagaacctcttagtagtcttaaagataacatgagtcactttaatgacatagacttgcaagctagtaagcctctaaaatcagagattagcaatctttttgggtactcaactcagcccaa,16.0,MSTRG.4662.1,0,EUPLOTES,negative,EUPLOTES
cttacttgcaaacatgaatctaataaattagag,ttaagaaggcataagagttttgctaaaaataaagatttgaagaatattactactaagtttggcaagagtaaacagagaagaagtaccatttctggctctccgacaaaatcagtcagatgcccttctgcaaaaaagagcctaacagatagaccaagaagaggaggtatccttgccaggaagaatcttacttgcaaacatgaatctaataaattagagatgctgatgaacctcatctatcgtacaccgaatgtagacctgattgaaaataggatcgatggactgataagaagtaactctatattgaacaaagtcgagaagagagtagctcactccggcattaagacttacaggttttctcctaatttactgaagaagataattccaaagaagataaaattc,16.0,MSTRG.14742.1,0,EUPLOTES,negative,EUPLOTES
atagtagaagatacagtctccatggccagtaca,ttgttgataaaaatacaggattttatctccaattaaagtcgagaaaataagccaagtagcagtaaacgtagataaacgaataagcatggaaaacaataatagaatagaccaaattgaagagatttctcacaattcgtatttgaattactgttagatttaggaataacaaatcgaactcgacacatagtagaagatacagtctccatggccagtacaaaagccaccatgaagaagaggaaaaagaacagataaatgctatcaaagaaaataaagttaacgatcaatatacaagcatgattgttattaatttaaccaactttagataatcagattataatcagcgacgatcaaagcaaaaaggtaaagaacaaccaattcaaattgaacatcagaaaggga,16.0,CUFF.17967.2,0,EUPLOTES,negative,EUPLOTES
cctcgtctttgtctccagaaaataagaaaacaa,catcaataaatagagtcaatgttagaagtatgtctaaattcaaaccaaacgaaattctaaataaagcaagaatgccaacataattaatttagattaagcttgatagttcattagtactcaacaataaaaatatttcaaagggtaatattccagaatcaaaattaagaaataaaattattcctacctcgtctttgtctccagaaaataagaaaacaaataaatcagttatgttcgaaaatgttaaagagatggaaagccaggacaagtcgcaaaatacactaacacatttgaaagaaagcaataatggtagtccttccaaattttaaaactgaaaataatcttgcagatgtagttcgatctagagataataaagcttataacagtactctaaacttaaaa,16.0,CUFF.22392.1,0,EUPLOTES,negative,EUPLOTES
aaaaatgacaaagatctgaacattagttctttc,ttaattttgttctgatcacctaattgtaagcccaaaaacgatactcaaaagatgaggaaactttattggaacattaaaagtaatctcttgagtttatttatgctaactatttacatacgaagcttttacgaaacattccaattcttggctcttgctggcttatcagctacttggaacaacgacaaaaatgacaaagatctgaacattagttctttcatttttgccattgttttgttgtttttatgcacaggtttcttcttatggtcactctaccattactttggatcccgctctgacaatcctcgaaatctcaaaatctctcaggagtttacgaatggagcaaaggagaataatagcggtaaactatatccagtgcttggattgctgagaagaggtctc,16.0,MSTRG.9455.1,0,EUPLOTES,negative,EUPLOTES
agaagactgggagaactctcagatactatatct,agaagaagagaaggccaggagtagctcgaaagaggaggaatttaaggtttacccaaagaaccctatgactgactctaaagatgatcagtcggacactctccctccgaaatcttacagtgtaaagaaagccaatgtaggagaactaaacaagtacgattttgagatctcttattccaaataatgagaagactgggagaactctcagatactatatctgcaagtatgatgaatgaaggcgtaaatttaacaagacttggaactttattgatcacgctaggatacacacaggagagaagccttacaaatgtgagctgtgtggcaaagagtttgctcagaaggggaactacaacaaacacaggaatacccaccagcatagtgccaagaagacctcagtaatga,16.0,MSTRG.26803.1,0,EUPLOTES,negative,EUPLOTES
cttacttgcaaacatgaatctaataaattagag,ttaagaaggcataagagttttgctaaaaataaagatttgaagaatattactactaagtttggcaagagtaaacagagaagaagtaccatttctggctctccgacaaaatcagtcagatgcccttctgcaaaaaagagcctaacagatagaccaagaagaggaggtatccttgccaggaagaatcttacttgcaaacatgaatctaataaattagagatgctgatgaacctcatctatcgtacaccgaatgtagacctgattgaaaataggatcgatggactgataagaagtaactctatattgaacaaagtcgagaagagagtagctcactccggcattaagacttacaggttttctcctaatttactgaagaagataattccaaagaagataaaattc,16.0,MSTRG.14742.1,0,EUPLOTES,negative,EUPLOTES
1 FS_period 399bp fs_position DNA_seqid label source FS_type dataset
2 gtgtgaacacaatagtgagtgacatactaaacg ggatatgtaacatggacaagtcattgtgtaggtatccaagaccaatagcctttactttcaaagggaaagaagaattcaggaccttgtttggaggactcatctcgatgtcgattcaggtggtcattgtgctctatgcttatattatgctaaagataatgatagaacgtaatgacacatcaaaaagtgtgaacacaatagtgagtgacatactaaacgacaaatctccagtatctctcaatacaacagatttctcgtttgcatttgatgcttttattcttggcgatgataatttcgatttcaacaataaccaatacttcggaattgagctacttcaatggattaagcagccagatactggagaactatcatccactaatattccatatgaaagatgtggaa 16.0 MSTRG.18491.1 0 EUPLOTES negative EUPLOTES
3 gtctcagaagagtctgaggaatatctccaagga caaattaataacaaatatgaattccatcaacaacttttatggagacgagaacttatcagatgaacttctgagtgaagatgtcgtgtcttgagaagtaagaggatcagaaaagatcttgcataacatggggagaaagtctctcagtaataagaagcctttaagcggagtggagttggactgagagtctcagaagagtctgaggaatatctccaaggataaaatttgttcgcaaggaagatctatctttaggcagaagaagtcaaaatcttgtgatcaagtagaagaacctcttagtagtcttaaagataacatgagtcactttaatgacatagacttgcaagctagtaagcctctaaaatcagagattagcaatctttttgggtactcaactcagcccaa 16.0 MSTRG.4662.1 0 EUPLOTES negative EUPLOTES
4 cttacttgcaaacatgaatctaataaattagag ttaagaaggcataagagttttgctaaaaataaagatttgaagaatattactactaagtttggcaagagtaaacagagaagaagtaccatttctggctctccgacaaaatcagtcagatgcccttctgcaaaaaagagcctaacagatagaccaagaagaggaggtatccttgccaggaagaatcttacttgcaaacatgaatctaataaattagagatgctgatgaacctcatctatcgtacaccgaatgtagacctgattgaaaataggatcgatggactgataagaagtaactctatattgaacaaagtcgagaagagagtagctcactccggcattaagacttacaggttttctcctaatttactgaagaagataattccaaagaagataaaattc 16.0 MSTRG.14742.1 0 EUPLOTES negative EUPLOTES
atagtagaagatacagtctccatggccagtaca ttgttgataaaaatacaggattttatctccaattaaagtcgagaaaataagccaagtagcagtaaacgtagataaacgaataagcatggaaaacaataatagaatagaccaaattgaagagatttctcacaattcgtatttgaattactgttagatttaggaataacaaatcgaactcgacacatagtagaagatacagtctccatggccagtacaaaagccaccatgaagaagaggaaaaagaacagataaatgctatcaaagaaaataaagttaacgatcaatatacaagcatgattgttattaatttaaccaactttagataatcagattataatcagcgacgatcaaagcaaaaaggtaaagaacaaccaattcaaattgaacatcagaaaggga 16.0 CUFF.17967.2 0 EUPLOTES negative EUPLOTES
cctcgtctttgtctccagaaaataagaaaacaa catcaataaatagagtcaatgttagaagtatgtctaaattcaaaccaaacgaaattctaaataaagcaagaatgccaacataattaatttagattaagcttgatagttcattagtactcaacaataaaaatatttcaaagggtaatattccagaatcaaaattaagaaataaaattattcctacctcgtctttgtctccagaaaataagaaaacaaataaatcagttatgttcgaaaatgttaaagagatggaaagccaggacaagtcgcaaaatacactaacacatttgaaagaaagcaataatggtagtccttccaaattttaaaactgaaaataatcttgcagatgtagttcgatctagagataataaagcttataacagtactctaaacttaaaa 16.0 CUFF.22392.1 0 EUPLOTES negative EUPLOTES
aaaaatgacaaagatctgaacattagttctttc ttaattttgttctgatcacctaattgtaagcccaaaaacgatactcaaaagatgaggaaactttattggaacattaaaagtaatctcttgagtttatttatgctaactatttacatacgaagcttttacgaaacattccaattcttggctcttgctggcttatcagctacttggaacaacgacaaaaatgacaaagatctgaacattagttctttcatttttgccattgttttgttgtttttatgcacaggtttcttcttatggtcactctaccattactttggatcccgctctgacaatcctcgaaatctcaaaatctctcaggagtttacgaatggagcaaaggagaataatagcggtaaactatatccagtgcttggattgctgagaagaggtctc 16.0 MSTRG.9455.1 0 EUPLOTES negative EUPLOTES
agaagactgggagaactctcagatactatatct agaagaagagaaggccaggagtagctcgaaagaggaggaatttaaggtttacccaaagaaccctatgactgactctaaagatgatcagtcggacactctccctccgaaatcttacagtgtaaagaaagccaatgtaggagaactaaacaagtacgattttgagatctcttattccaaataatgagaagactgggagaactctcagatactatatctgcaagtatgatgaatgaaggcgtaaatttaacaagacttggaactttattgatcacgctaggatacacacaggagagaagccttacaaatgtgagctgtgtggcaaagagtttgctcagaaggggaactacaacaaacacaggaatacccaccagcatagtgccaagaagacctcagtaatga 16.0 MSTRG.26803.1 0 EUPLOTES negative EUPLOTES

View File

@ -24,163 +24,170 @@ class PRFPredictor:
model_dir = resource_filename('FScanpy', 'pretrained')
try:
# 加载模型
self.gb_model = self._load_pickle(os.path.join(model_dir, 'GradientBoosting_all.pkl'))
self.cnn_model = self._load_pickle(os.path.join(model_dir, 'BiLSTM-CNN_all.pkl'))
# 加载模型 - 使用新的命名约定
self.short_model = self._load_pickle(os.path.join(model_dir, 'short.pkl')) # HistGB模型
self.long_model = self._load_pickle(os.path.join(model_dir, 'long.pkl')) # BiLSTM-CNN模型
# 初始化特征提取器和CNN处理器使用与训练时相同的序列长度
self.gb_seq_length = 33 # HistGradientBoosting使用的序列长度
self.cnn_seq_length = 399 # BiLSTM-CNN使用的序列长度
self.short_seq_length = 33 # HistGB使用的序列长度
self.long_seq_length = 399 # BiLSTM-CNN使用的序列长度
# 初始化特征提取器和CNN输入处理器
self.feature_extractor = SequenceFeatureExtractor(seq_length=self.gb_seq_length)
self.cnn_processor = CNNInputProcessor(max_length=self.cnn_seq_length)
self.feature_extractor = SequenceFeatureExtractor(seq_length=self.short_seq_length)
self.cnn_processor = CNNInputProcessor(max_length=self.long_seq_length)
# 检测模型类型以优化预测性能
self._detect_model_types()
except FileNotFoundError as e:
raise FileNotFoundError(f"无法找到模型文件: {str(e)}")
raise FileNotFoundError(f"无法找到模型文件: {str(e)}。请确保模型文件 'short.pkl''long.pkl' 存在于 {model_dir}")
except Exception as e:
raise Exception(f"加载模型出错: {str(e)}")
def _load_pickle(self, path):
    """
    Load a serialized model file with ``joblib``.

    Args:
        path: Path of the model file.

    Returns:
        The object stored in the file.

    Raises:
        FileNotFoundError: If loading fails for any reason; the original
            error is attached as ``__cause__``.
    """
    # NOTE(security): joblib.load unpickles the file, which can execute
    # arbitrary code. Only point this at trusted model files.
    try:
        return joblib.load(path)
    except Exception as e:
        raise FileNotFoundError(f"无法加载模型文件 {path}: {str(e)}") from e
def _detect_model_types(self):
    """
    Record which prediction interface each loaded model offers.

    Sets ``self.short_is_sklearn`` and ``self.long_is_sklearn`` to ``True``
    when the corresponding model has a ``predict_proba`` attribute.
    """
    self.short_is_sklearn = hasattr(self.short_model, 'predict_proba')
    self.long_is_sklearn = hasattr(self.long_model, 'predict_proba')
def _predict_model(self, model, features, is_sklearn, seq_length):
    """
    Run one model on one sample and return the positive-class score.

    Args:
        model: The loaded model object.
        features: A feature vector when ``is_sklearn`` is true, otherwise
            the raw sequence string.
        is_sklearn: ``True`` to call ``model.predict_proba``; ``False`` to
            call ``model.predict``.
        seq_length: Sequence length the model works on. When it equals
            ``self.long_seq_length`` the input is prepared by
            ``self.cnn_processor``; otherwise the sequence is integer-encoded.

    Returns:
        The score at column 1 when the prediction has more than one column,
        otherwise the single predicted value.

    Raises:
        Exception: Wraps any failure; the original error is ``__cause__``.
    """
    try:
        if is_sklearn:
            # predict_proba expects a 2-D (1, n_features) array.
            features_2d = np.asarray(features).reshape(1, -1)
            return model.predict_proba(features_2d)[0][1]

        if seq_length == self.long_seq_length:
            # Long sequences go through the CNN input processor.
            model_input = self.cnn_processor.prepare_sequence(features)
        else:
            # Short sequences are integer-encoded; unknown bases map to 0.
            base_to_num = {'A': 1, 'T': 2, 'G': 3, 'C': 4, 'N': 0}
            seq_numeric = [base_to_num.get(base, 0) for base in features.upper()]
            model_input = np.array(seq_numeric).reshape(1, len(seq_numeric), 1)

        # Not every model's predict() accepts ``verbose``; retry without it.
        try:
            pred = model.predict(model_input, verbose=0)
        except TypeError:
            pred = model.predict(model_input)

        if isinstance(pred, list):
            pred = pred[0]

        if hasattr(pred, 'shape') and len(pred.shape) > 1 and pred.shape[1] > 1:
            return pred[0][1]

        first = pred[0]
        # NumPy scalars expose __getitem__ but cannot be indexed with 0, so
        # decide by dimensionality rather than by attribute presence.
        return first[0] if np.ndim(first) > 0 else first
    except Exception as e:
        raise Exception(f"模型预测失败: {str(e)}") from e
def predict_single_position(self, fs_period, full_seq, short_threshold=0.1, ensemble_weight=0.4):
    '''
    Predict the PRF status of a single position.

    The Short model is evaluated first. If its probability is below
    ``short_threshold`` the Long model is skipped and both the Long and the
    ensemble probability are reported as 0.0.

    Args:
        fs_period: Short sequence used by the Short model; trimmed to
            ``self.short_seq_length`` when longer.
        full_seq: Full sequence used by the Long model.
        short_threshold: Probability threshold of the Short model
            (default 0.1).
        ensemble_weight: Weight of the Short model in the ensemble
            (default 0.4); the Long model receives ``1 - ensemble_weight``.

    Returns:
        dict: Keys ``Short_Probability``, ``Long_Probability``,
        ``Ensemble_Probability`` and ``Ensemble_Weights``.

    Raises:
        ValueError: If ``ensemble_weight`` is outside ``[0.0, 1.0]``.
        Exception: Wraps any other failure; the original error is
            ``__cause__``.
    '''
    # Validate up front so callers get a ValueError, matching
    # predict_sequence, instead of a generic wrapped Exception.
    if not (0.0 <= ensemble_weight <= 1.0):
        raise ValueError("ensemble_weight 必须在 0.0 到 1.0 之间")

    long_weight = 1.0 - ensemble_weight
    weights_label = f'Short:{ensemble_weight:.1f}, Long:{long_weight:.1f}'

    try:
        if len(fs_period) > self.short_seq_length:
            fs_period = self.feature_extractor.trim_sequence(fs_period, self.short_seq_length)

        # Short model. A failure is reported and treated as probability 0.
        try:
            if self.short_is_sklearn:
                short_features = self.feature_extractor.extract_features(fs_period)
                short_prob = self._predict_model(self.short_model, short_features, True, self.short_seq_length)
            else:
                short_prob = self._predict_model(self.short_model, fs_period, False, self.short_seq_length)
        except Exception as e:
            print(f"Short模型预测时出错: {str(e)}")
            short_prob = 0.0

        # Below the gate: skip the Long model entirely.
        if short_prob < short_threshold:
            return {
                'Short_Probability': short_prob,
                'Long_Probability': 0.0,
                'Ensemble_Probability': 0.0,
                'Ensemble_Weights': weights_label
            }

        # Long model. A failure is reported and treated as probability 0.
        try:
            if self.long_is_sklearn:
                long_features = self.feature_extractor.extract_features(full_seq)
                long_prob = self._predict_model(self.long_model, long_features, True, self.long_seq_length)
            else:
                long_prob = self._predict_model(self.long_model, full_seq, False, self.long_seq_length)
        except Exception as e:
            print(f"Long模型预测时出错: {str(e)}")
            long_prob = 0.0

        ensemble_prob = ensemble_weight * short_prob + long_weight * long_prob

        return {
            'Short_Probability': short_prob,
            'Long_Probability': long_prob,
            'Ensemble_Probability': ensemble_prob,
            'Ensemble_Weights': weights_label
        }
    except Exception as e:
        raise Exception(f"预测过程出错: {str(e)}") from e
def predict_full(self, sequence, window_size=3, gb_threshold=0.1, plot=False):
def predict_sequence(self, sequence, window_size=3, short_threshold=0.1, ensemble_weight=0.4):
"""
预测完整序列中的PRF位点
预测完整序列中的PRF位点滑动窗口方法
Args:
sequence: 输入DNA序列
window_size: 滑动窗口大小 (默认为3)
gb_threshold: GB模型概率阈值 (默认为0.1)
plot: 是否绘制预测结果图表 (默认为False)
short_threshold: short模型概率阈值 (默认为0.1)
ensemble_weight: short模型在集成中的权重 (默认为0.4)
Returns:
if plot=False:
pd.DataFrame: 包含预测结果的DataFrame
if plot=True:
tuple: (pd.DataFrame, matplotlib.figure.Figure)
pd.DataFrame: 包含预测结果的DataFrame
"""
if window_size < 1:
raise ValueError("窗口大小必须大于等于1")
if gb_threshold < 0:
raise ValueError("GB阈值必须大于等于0")
if short_threshold < 0:
raise ValueError("short模型阈值必须大于等于0")
if not (0.0 <= ensemble_weight <= 1.0):
raise ValueError("ensemble_weight 必须在 0.0 到 1.0 之间")
results = []
long_weight = 1.0 - ensemble_weight
try:
# 确保序列为字符串并转换为大写
@ -188,210 +195,207 @@ class PRFPredictor:
# 滑动窗口预测
for pos in range(0, len(sequence) - 2, window_size):
# 提取窗口序列 - 使用与训练时相同的窗口大小
# 提取窗口序列
fs_period, full_seq = extract_window_sequences(sequence, pos)
if fs_period is None or full_seq is None:
continue
# 预测并记录结果
pred = self.predict_single_position(fs_period, full_seq, gb_threshold)
pred = self.predict_single_position(fs_period, full_seq, short_threshold, ensemble_weight)
pred.update({
'Position': pos,
'Codon': sequence[pos:pos+3],
'33bp': fs_period,
'399bp': full_seq
'Short_Sequence': fs_period, # 更清晰的命名
'Long_Sequence': full_seq # 更清晰的命名
})
results.append(pred)
# 创建结果DataFrame
results_df = pd.DataFrame(results)
# 如需绘图
if plot:
# 创建图形
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 10), height_ratios=[2, 1])
# 绘制折线图
ax1.plot(results_df['Position'], results_df['GB_Probability'],
label='GB模型', alpha=0.7, linewidth=1.5)
ax1.plot(results_df['Position'], results_df['CNN_Probability'],
label='CNN模型', alpha=0.7, linewidth=1.5)
ax1.plot(results_df['Position'], results_df['Voting_Probability'],
label='投票模型', linewidth=2, color='red')
ax1.set_xlabel('序列位置')
ax1.set_ylabel('移码概率')
ax1.set_title('移码预测概率')
ax1.legend()
ax1.grid(True, alpha=0.3)
# 准备热图数据
positions = results_df['Position'].values
probabilities = results_df['Voting_Probability'].values
# 创建热图矩阵
heatmap_matrix = np.zeros((1, len(positions)))
heatmap_matrix[0, :] = probabilities
# 绘制热图
im = ax2.imshow(heatmap_matrix, aspect='auto', cmap='YlOrRd',
extent=[min(positions), max(positions), 0, 1])
# 添加颜色条
cbar = plt.colorbar(im, ax=ax2)
cbar.set_label('移码概率')
# 设置热图轴标签
ax2.set_xlabel('序列位置')
ax2.set_title('移码概率热图')
ax2.set_yticks([])
# 调整布局
plt.tight_layout()
return results_df, fig
return results_df
except Exception as e:
raise Exception(f"序列预测过程出错: {str(e)}")
def plot_sequence_prediction(self, sequence, window_size=3, short_threshold=0.65,
                             long_threshold=0.8, ensemble_weight=0.4, title=None, save_path=None,
                             figsize=(12, 6), dpi=300):
    """
    Plot the frameshift-probability profile predicted for a sequence.

    The figure has a thin heat-map strip on top and a bar chart below. Only
    positions whose Short and Long probabilities both reach their threshold
    are drawn; all other positions are shown as 0.

    Args:
        sequence: Input DNA sequence.
        window_size: Sliding-window step (default 3).
        short_threshold: Filter threshold for the Short model (default 0.65).
        long_threshold: Filter threshold for the Long model (default 0.8).
        ensemble_weight: Weight of the Short model in the ensemble
            (default 0.4).
        title: Optional figure title.
        save_path: Optional output path. When given, the figure is saved
            there; if the path ends in ``.png`` a ``.pdf`` copy is written
            next to it.
        figsize: Figure size (default ``(12, 6)``).
        dpi: Resolution used when saving (default 300).

    Returns:
        tuple: ``(pd.DataFrame, matplotlib.figure.Figure)`` - the unfiltered
        prediction table and the figure.

    Raises:
        ValueError: If ``ensemble_weight`` is outside ``[0.0, 1.0]``.
        Exception: Wraps any other failure (including an empty prediction
            table); the original error is ``__cause__``.
    """
    if not (0.0 <= ensemble_weight <= 1.0):
        raise ValueError("ensemble_weight 必须在 0.0 到 1.0 之间")

    long_weight = 1.0 - ensemble_weight

    try:
        # The prediction call uses a fixed 0.1 gate rather than the
        # ``short_threshold`` argument; the plot thresholds are applied to
        # the returned table below.
        # NOTE(review): presumably intentional so the table keeps
        # low-scoring rows - confirm.
        results_df = self.predict_sequence(sequence, window_size=window_size,
                                           short_threshold=0.1, ensemble_weight=ensemble_weight)

        if results_df.empty:
            raise ValueError("预测结果为空，请检查输入序列")

        seq_length = len(sequence)

        # Width (in bases) of each heat-map mark; grows with sequence length.
        prob_width = max(1, seq_length // 300)

        fig = plt.figure(figsize=figsize)
        gs = fig.add_gridspec(2, 1, height_ratios=[0.15, 1], hspace=0.3)

        if title:
            fig.suptitle(title, y=0.95, fontsize=12)
        else:
            fig.suptitle(f'序列移码概率预测结果 (权重 {ensemble_weight:.1f}:{long_weight:.1f})', y=0.95, fontsize=12)

        # --- top strip: heat map of positions passing both thresholds ---
        ax0 = fig.add_subplot(gs[0])
        prob_data = np.zeros((1, seq_length))

        for _, row in results_df.iterrows():
            pos = int(row['Position'])
            if (row['Short_Probability'] >= short_threshold and
                    row['Long_Probability'] >= long_threshold):
                start = max(0, pos - prob_width // 2)
                end = min(seq_length, pos + prob_width // 2 + 1)
                prob_data[0, start:end] = row['Ensemble_Probability']

        ax0.imshow(prob_data, cmap='Reds', aspect='auto', vmin=0, vmax=1,
                   interpolation='nearest')
        ax0.set_xticks([])
        ax0.set_yticks([])
        ax0.set_title(f'预测概率热图 (Short≥{short_threshold}, Long≥{long_threshold})',
                      pad=5, fontsize=10)

        # --- main panel: bar chart of the filtered ensemble probability ---
        ax1 = fig.add_subplot(gs[1])

        filtered_probs = results_df['Ensemble_Probability'].copy()
        mask = ((results_df['Short_Probability'] < short_threshold) |
                (results_df['Long_Probability'] < long_threshold))
        filtered_probs[mask] = 0

        ax1.bar(results_df['Position'], filtered_probs,
                alpha=0.7, color='darkred', width=max(1, window_size))

        step = max(seq_length // 10, 50)
        x_ticks = np.arange(0, seq_length, step)
        ax1.set_xticks(x_ticks)
        ax1.tick_params(axis='x', rotation=45)

        ax1.set_xlabel('序列位置 (bp)', fontsize=10)
        ax1.set_ylabel('移码概率', fontsize=10)
        ax1.set_title(f'移码概率分布 (集成权重 {ensemble_weight:.1f}:{long_weight:.1f})', fontsize=11)

        ax1.set_ylim(0, 1)
        ax1.grid(True, alpha=0.3)

        info_text = (f'过滤阈值: Short≥{short_threshold}, Long≥{long_threshold}\n'
                     f'集成权重: Short:{ensemble_weight:.1f}, Long:{long_weight:.1f}')
        ax1.text(0.02, 0.95, info_text, transform=ax1.transAxes,
                 fontsize=9, verticalalignment='top',
                 bbox=dict(boxstyle='round', facecolor='lightgray', alpha=0.8))

        # Keep both panels on the same x range.
        for ax in [ax0, ax1]:
            ax.set_xlim(-1, seq_length)

        # Operate on ``fig`` explicitly instead of pyplot's "current figure".
        fig.tight_layout()

        if save_path:
            save_path = os.fspath(save_path)
            fig.savefig(save_path, dpi=dpi, bbox_inches='tight')
            # Swap only the extension; str.replace would also rewrite a
            # '.png' that appears earlier in the path.
            root, ext = os.path.splitext(save_path)
            if ext == '.png':
                fig.savefig(root + '.pdf', bbox_inches='tight')
            print(f"图片已保存至: {save_path}")

        return results_df, fig

    except Exception as e:
        raise Exception(f"绘制序列预测图时出错: {str(e)}") from e
def predict_regions(self, sequences, short_threshold=0.1, ensemble_weight=0.4):
'''
预测区域序列批量预测已知的399bp序列
Args:
sequences: 399bp序列或包含399bp序列的DataFrame/Series/list
short_threshold: short模型概率阈值 (默认为0.1)
ensemble_weight: short模型在集成中的权重 (默认为0.4)
Returns:
DataFrame: 包含所有序列预测概率的DataFrame
'''
try:
# 如果输入是DataFrame或Series转换为列表
if isinstance(seq, (pd.DataFrame, pd.Series)):
seq = seq.tolist()
# 验证权重参数
if not (0.0 <= ensemble_weight <= 1.0):
raise ValueError("ensemble_weight 必须在 0.0 到 1.0 之间")
# 如果输入是单个字符串,转换为列表
if isinstance(seq, str):
seq = [seq]
# 统一输入格式
if isinstance(sequences, (pd.DataFrame, pd.Series)):
sequences = sequences.tolist()
elif isinstance(sequences, str):
sequences = [sequences]
results = []
for i, seq399 in enumerate(seq):
for i, seq399 in enumerate(sequences):
try:
# 从399bp序列中截取中心的33bp (GB模型使用)
seq33 = self._extract_center_sequence(seq399, target_length=self.gb_seq_length)
# 从399bp序列中截取中心的33bp (short模型使用)
seq33 = self._extract_center_sequence(seq399, target_length=self.short_seq_length)
# GB模型预测 - 确保输入是二维数组
try:
gb_features = self.feature_extractor.extract_features(seq33)
# 检查特征结构并确保是一维数组
if isinstance(gb_features, np.ndarray):
# 如果是多维数组,进行扁平化处理
if gb_features.ndim > 1:
print(f"警告: 序列 {i+1} 的特征是{gb_features.ndim}维数组,进行扁平化处理")
gb_features = gb_features.flatten()
# 明确将特征转换为二维数组,正确形状为(1, n_features)
gb_features_2d = np.array([gb_features])
# 再次检查维度
if gb_features_2d.ndim != 2:
raise ValueError(f"处理后特征仍为{gb_features_2d.ndim}维,需要二维数组")
gb_prob = self.gb_model.predict_proba(gb_features_2d)[0][1]
except Exception as e:
print(f"GB模型预测序列 {i+1} 时出错: {str(e)}")
# 出错时设置概率为0
gb_prob = 0.0
# 如果GB概率低于阈值添加低概率结果
if gb_prob < gb_threshold:
results.append({
'GB_Probability': gb_prob,
'CNN_Probability': 0.0,
'Voting_Probability': 0.0,
'33bp': seq33,
'399bp': seq399
})
continue
# CNN模型预测
try:
# 首先检查CNN模型的类型 - 通过尝试识别模型类型
is_sklearn_model = False
# 检测模型类型的方法
if hasattr(self.cnn_model, 'predict_proba'):
# 这可能是一个scikit-learn模型
is_sklearn_model = True
if is_sklearn_model:
# 如果是sklearn模型 (如HistGradientBoostingClassifier)使用与GB相同的特征提取
# 为CNN模型使用相同的特征提取方法但从399bp序列中提取
cnn_features = self.feature_extractor.extract_features(seq399)
if isinstance(cnn_features, np.ndarray) and cnn_features.ndim > 1:
cnn_features = cnn_features.flatten()
# 转为二维数组
cnn_features_2d = np.array([cnn_features])
cnn_pred = self.cnn_model.predict_proba(cnn_features_2d)
cnn_prob = cnn_pred[0][1]
else:
# 假设是深度学习模型,需要三维输入
cnn_input = self.cnn_processor.prepare_sequence(seq399)
# 尝试不同的预测方法
try:
# 先尝试不带参数
cnn_pred = self.cnn_model.predict(cnn_input)
except TypeError:
try:
# 再尝试带verbose参数
cnn_pred = self.cnn_model.predict(cnn_input, verbose=0)
except Exception:
# 最后尝试将输入重塑为2D
reshaped_input = cnn_input.reshape(1, -1)
cnn_pred = self.cnn_model.predict(reshaped_input)
# 处理预测结果
if isinstance(cnn_pred, list):
cnn_pred = cnn_pred[0]
# 提取概率值
if hasattr(cnn_pred, 'shape') and len(cnn_pred.shape) > 1 and cnn_pred.shape[1] > 1:
cnn_prob = cnn_pred[0][1]
else:
cnn_prob = cnn_pred[0][0] if hasattr(cnn_pred[0], '__getitem__') else cnn_pred[0]
except Exception as e:
print(f"CNN模型预测序列 {i+1} 时出错: {str(e)}")
# 出错时设置概率为0
cnn_prob = 0.0
# 使用4:6的加权平均替代投票模型
try:
voting_prob = 0.4 * gb_prob + 0.6 * cnn_prob
except Exception as e:
print(f"计算加权平均时出错: {str(e)}")
# 出错时使用简单平均
voting_prob = (gb_prob + cnn_prob) / 2
results.append({
'GB_Probability': gb_prob,
'CNN_Probability': cnn_prob,
'Voting_Probability': voting_prob,
'33bp': seq33,
'399bp': seq399
# 使用统一的预测方法
pred_result = self.predict_single_position(seq33, seq399, short_threshold, ensemble_weight)
pred_result.update({
'Short_Sequence': seq33,
'Long_Sequence': seq399
})
results.append(pred_result)
except Exception as e:
print(f"处理第 {i+1} 个序列时出错: {str(e)}")
long_weight = 1.0 - ensemble_weight
results.append({
'GB_Probability': 0.0,
'CNN_Probability': 0.0,
'Voting_Probability': 0.0,
'33bp': self._extract_center_sequence(seq399, target_length=self.gb_seq_length) if len(seq399) >= self.gb_seq_length else seq399,
'399bp': seq399
'Short_Probability': 0.0,
'Long_Probability': 0.0,
'Ensemble_Probability': 0.0,
'Ensemble_Weights': f'Short:{ensemble_weight:.1f}, Long:{long_weight:.1f}',
'Short_Sequence': self._extract_center_sequence(seq399, target_length=self.short_seq_length) if len(seq399) >= self.short_seq_length else seq399,
'Long_Sequence': seq399
})
return pd.DataFrame(results)
@ -424,4 +428,60 @@ class PRFPredictor:
end = len(sequence)
start = end - target_length
return sequence[start:end]
return sequence[start:end]
# 兼容性方法(向后兼容,但标记为废弃)
def predict_full(self, sequence, window_size=3, short_threshold=0.1, short_weight=0.4, plot=False,
                 gb_threshold=None):
    """
    Deprecated: use predict_sequence() instead.

    Backward-compatible wrapper that forwards to predict_sequence() and then
    mirrors the new result columns under their legacy names so that code
    written against the old API keeps working.

    Args:
        sequence: input sequence, forwarded unchanged to predict_sequence().
        window_size: sliding-window step, forwarded to predict_sequence().
        short_threshold: Short-model probability threshold.
        short_weight: weight of the Short model in the ensemble; passed to
            predict_sequence() as its ensemble weight.
        plot: when true, additionally calls plot_sequence_prediction() and
            returns ``(results_df, fig)`` instead of just ``results_df``.
        gb_threshold: legacy keyword name of ``short_threshold`` used by older
            releases; when given (not None) it takes precedence.

    Returns:
        DataFrame returned by predict_sequence() with legacy alias columns
        added, or ``(DataFrame, figure)`` when ``plot`` is true.
    """
    import warnings
    warnings.warn("predict_full() 已废弃,请使用 predict_sequence() 方法", DeprecationWarning, stacklevel=2)

    # Older callers passed the threshold as `gb_threshold=`; honour it so the
    # compatibility wrapper does not raise TypeError on legacy call sites.
    if gb_threshold is not None:
        short_threshold = gb_threshold

    results_df = self.predict_sequence(sequence, window_size, short_threshold, short_weight)

    # New column -> legacy column names. The first four entries keep the
    # original insertion order; the probability aliases are appended last so
    # existing column positions are unchanged.
    legacy_aliases = (
        ('Ensemble_Probability', ('Voting_Probability', 'Weighted_Probability')),
        ('Ensemble_Weights', ('Weight_Info',)),
        ('Short_Sequence', ('33bp',)),
        ('Long_Sequence', ('399bp',)),
        ('Short_Probability', ('GB_Probability',)),
        ('Long_Probability', ('CNN_Probability',)),
    )
    for new_name, old_names in legacy_aliases:
        if new_name in results_df.columns:
            for old_name in old_names:
                results_df[old_name] = results_df[new_name]

    if plot:
        # NOTE(review): 0.65 and 0.8 are passed positionally; they appear to be
        # the plot's short/long filter thresholds - confirm against
        # plot_sequence_prediction()'s signature.
        _, fig = self.plot_sequence_prediction(sequence, window_size, 0.65, 0.8, short_weight)
        return results_df, fig

    return results_df
def predict_region(self, seq, short_threshold=0.1, short_weight=0.4, gb_threshold=None):
    """
    Deprecated: use predict_regions() instead.

    Backward-compatible wrapper that forwards to predict_regions() and then
    mirrors the new result columns under their legacy names so that code
    written against the old API keeps working.

    Args:
        seq: sequence input, forwarded unchanged to predict_regions().
        short_threshold: Short-model probability threshold.
        short_weight: weight of the Short model in the ensemble; passed to
            predict_regions() as its ensemble weight.
        gb_threshold: legacy keyword name of ``short_threshold`` used by older
            releases; when given (not None) it takes precedence.

    Returns:
        DataFrame returned by predict_regions() with legacy alias columns added.
    """
    import warnings
    warnings.warn("predict_region() 已废弃,请使用 predict_regions() 方法", DeprecationWarning, stacklevel=2)

    # Older callers passed the threshold as `gb_threshold=`; honour it so the
    # compatibility wrapper does not raise TypeError on legacy call sites.
    if gb_threshold is not None:
        short_threshold = gb_threshold

    results_df = self.predict_regions(seq, short_threshold, short_weight)

    # New column -> legacy column names. The first four entries keep the
    # original insertion order; the probability aliases are appended last so
    # existing column positions are unchanged.
    legacy_aliases = (
        ('Ensemble_Probability', ('Voting_Probability', 'Weighted_Probability')),
        ('Ensemble_Weights', ('Weight_Info',)),
        ('Short_Sequence', ('33bp',)),
        ('Long_Sequence', ('399bp',)),
        ('Short_Probability', ('GB_Probability',)),
        ('Long_Probability', ('CNN_Probability',)),
    )
    for new_name, old_names in legacy_aliases:
        if new_name in results_df.columns:
            for old_name in old_names:
                results_df[old_name] = results_df[new_name]

    return results_df

702
FScanpy_Demo.ipynb Normal file

File diff suppressed because one or more lines are too long

220
README.md
View File

@ -1,25 +1,149 @@
# FScanpy
## A Machine Learning-Based Framework for Programmed Ribosomal Frameshifting Prediction
FScanpy is a comprehensive Python package designed for the prediction of [Programmed Ribosomal Frameshifting (PRF)](https://en.wikipedia.org/wiki/Ribosomal_frameshift) sites in nucleotide sequences. By integrating advanced machine learning approaches (Gradient Boosting and BiLSTM-CNN) with the established [FScanR](https://github.com/seanchen607/FScanR.git) framework, FScanpy provides robust and accurate PRF site predictions. The package requires input sequences to be in the positive (5' to 3') orientation.
FScanpy is a comprehensive Python package designed for the prediction of [Programmed Ribosomal Frameshifting (PRF)](https://en.wikipedia.org/wiki/Ribosomal_frameshift) sites in nucleotide sequences. By integrating advanced machine learning approaches (HistGradientBoosting and BiLSTM-CNN) with the established [FScanR](https://github.com/seanchen607/FScanR.git) framework, FScanpy provides robust and accurate PRF site predictions. The package requires input sequences to be in the positive (5' to 3') orientation.
![FScanpy Architecture](/tutorial/image/structure.jpeg)
For detailed documentation and usage examples, please refer to our [tutorial](tutorial/tutorial.md).
## 🚀 What's New in v0.3.0
### Model Naming Optimization
- **Short Model** (`short.pkl`): HistGradientBoosting model for rapid screening
- **Long Model** (`long.pkl`): BiLSTM-CNN model for detailed analysis
- **Unified Interface**: Consistent parameter naming and clearer output fields
### Performance Improvements
- **Faster Prediction**: Optimized model type detection and reduced redundant operations
- **Better Error Handling**: More informative error messages and robust exception handling
- **Code Quality**: Reduced code duplication and improved maintainability
### 🎨 New Visualization Features
- **Sequence Plotting**: Built-in function for visualizing PRF prediction results
- **Dual Threshold Filtering**: Separate filtering for Short and Long models
- **Interactive Graphics**: Heatmap and bar chart visualization
- **Export Options**: Support for PNG and PDF output formats
### ⚖️ Ensemble Weighting System
- **Flexible Ensemble**: Control the contribution of Short and Long models
- **Weight Validation**: Automatic parameter validation and error handling
- **Clear Naming**: `ensemble_weight` parameter for intuitive usage
- **Visual Feedback**: Weight ratios displayed in plots and results
### 🔧 API Improvements
- **Method Renaming**: More intuitive method names
- `predict_sequence()`: Replaces `predict_full()` for sequence prediction
- `predict_regions()`: Replaces `predict_region()` for batch prediction
- **Field Standardization**: Consistent output field naming
- `Ensemble_Probability`: Main prediction result (replaces `Voting_Probability`)
- `Short_Sequence` / `Long_Sequence`: Clear sequence field names
- **Backward Compatibility**: Deprecated methods still work with warnings
## Core Features
- **Sequence Feature Extraction**: Support for extracting features from nucleic acid sequences, including base composition, k - mer features, and positional features.
- **Sequence Feature Extraction**: Support for extracting features from nucleic acid sequences, including base composition, k-mer features, and positional features.
- **Frameshift Hotspot Region Prediction**: Predict potential PRF sites in nucleotide sequences using machine learning models.
- **Feature Extraction**: Extract relevant features from sequences to assist in prediction.
- **Cross - Species Support**: Built - in databases for viruses, marine phages, Euplotes, etc., enabling PRF prediction across various species.
- **Cross-Species Support**: Built-in databases for viruses, marine phages, Euplotes, etc., enabling PRF prediction across various species.
- **Visualization Tools**: Built-in plotting functions for result visualization and analysis.
- **Ensemble Modeling**: Customizable ensemble weights for different prediction strategies.
## Main Advantages
- **High Accuracy**: Integrates multiple machine learning models to provide accurate PRF site predictions.
- **Efficiency**: Utilizes a sliding window approach and feature extraction techniques to rapidly scan sequences.
- **Versatility**: Supports PRF prediction across various species and can be combined with the [FScanR](https://github.com/seanchen607/FScanR.git) framework for enhanced accuracy.
- **User - Friendly**: Comes with detailed documentation and usage examples, making it easy for researchers to use.
- **User-Friendly**: Comes with detailed documentation and usage examples, making it easy for researchers to use.
- **Flexible**: Provides different prediction resolutions to suit different use cases.
## Quick Start
### Basic Prediction
```python
from FScanpy import predict_prf
# Single sequence prediction with default ensemble weights (0.4:0.6)
sequence = "ATGCGTACGT..."
results = predict_prf(sequence=sequence)
print(results[['Position', 'Short_Probability', 'Long_Probability', 'Ensemble_Probability']].head())
```
### Custom Ensemble Weighting
```python
# Adjust model weights for different prediction strategies
results_long_dominant = predict_prf(sequence=sequence, ensemble_weight=0.3) # 3:7 ratio (Long dominant)
results_equal_weight = predict_prf(sequence=sequence, ensemble_weight=0.5) # 5:5 ratio (Equal weight)
results_short_dominant = predict_prf(sequence=sequence, ensemble_weight=0.7) # 7:3 ratio (Short dominant)
# Compare ensemble probabilities
print("Long dominant:", results_long_dominant['Ensemble_Probability'].mean())
print("Equal weight:", results_equal_weight['Ensemble_Probability'].mean())
print("Short dominant:", results_short_dominant['Ensemble_Probability'].mean())
```
### Visualization with Custom Weights
```python
from FScanpy import plot_prf_prediction
import matplotlib.pyplot as plt
# Generate prediction plot with custom ensemble weighting
sequence = "ATGCGTACGT..."
results, fig = plot_prf_prediction(
sequence=sequence,
short_threshold=0.65, # HistGB threshold
long_threshold=0.8, # BiLSTM-CNN threshold
ensemble_weight=0.3, # Custom weight: 30% Short, 70% Long
title="Long-Dominant Ensemble PRF Prediction (3:7)",
save_path="prediction_result.png"
)
plt.show()
```
### Advanced Usage with New API
```python
from FScanpy import PRFPredictor
import matplotlib.pyplot as plt
# Create predictor instance
predictor = PRFPredictor()
# Use new sequence prediction method
results = predictor.predict_sequence(
sequence=sequence,
ensemble_weight=0.4
)
# Compare different ensemble configurations
weights = [0.2, 0.4, 0.6, 0.8]
weight_names = ["Long 80%", "Balanced", "Short 60%", "Short 80%"]
fig, axes = plt.subplots(2, 2, figsize=(15, 10))
axes = axes.flatten()
for i, (weight, name) in enumerate(zip(weights, weight_names)):
results = predictor.predict_sequence(sequence=sequence, ensemble_weight=weight)
ax = axes[i]
ax.bar(results['Position'], results['Ensemble_Probability'], alpha=0.7)
ax.set_title(f'{name} (Weight: {weight:.1f}:{1-weight:.1f})')
ax.set_ylabel('Probability')
plt.tight_layout()
plt.show()
```
### Batch Region Prediction
```python
# Predict multiple 399bp sequences
import pandas as pd
data = pd.DataFrame({
'Long_Sequence': ['ATGCGT...' * 60, 'GCTATAG...' * 57] # 399bp sequences
})
results = predict_prf(data=data, ensemble_weight=0.4)
print(results[['Ensemble_Probability', 'Ensemble_Weights']].head())
```
## Installation Requirements
- Python ≥ 3.7
- Dependencies are automatically handled during installation
@ -36,6 +160,94 @@ cd FScanpy-package
pip install -e .
```
## 🔄 Migration from Previous Versions
### API Changes Summary
```python
# OLD API (deprecated but still works)
results = predict_prf(sequence="ATGC...", short_weight=0.4)  # NOTE: the current predict_prf() signature only accepts ensemble_weight
results = predictor.predict_full(sequence, short_weight=0.4)
results = predictor.predict_region(sequences, short_weight=0.4)
# NEW API (recommended)
results = predict_prf(sequence="ATGC...", ensemble_weight=0.4)
results = predictor.predict_sequence(sequence, ensemble_weight=0.4)
results = predictor.predict_regions(sequences, ensemble_weight=0.4)
# Output field changes
# OLD: 'Voting_Probability', 'Weight_Info', '33bp', '399bp'
# NEW: 'Ensemble_Probability', 'Ensemble_Weights', 'Short_Sequence', 'Long_Sequence'
# Visualization with ensemble weights
results, fig = plot_prf_prediction(
sequence="ATGC...",
short_threshold=0.65,
long_threshold=0.8,
ensemble_weight=0.3 # 30% Short, 70% Long
)
```
### Backward Compatibility
- All old methods still work but will show deprecation warnings
- Old field names are automatically added for compatibility
- Gradual migration is supported
## Ensemble Weight Configuration Guide
### Recommended Weights for Different Scenarios:
| Scenario | ensemble_weight | Description | Use Case |
|----------|----------------|-------------|----------|
| **High Sensitivity** | 0.2-0.3 | Long model dominant | Detecting subtle PRF sites |
| **Balanced Detection** | 0.4-0.5 | Balanced ensemble (recommended) | General purpose prediction |
| **Fast Screening** | 0.6-0.7 | Short model dominant | Rapid initial screening |
| **Equal Contribution** | 0.5 | Equal weight to both models | Comparative analysis |
### Weight Selection Guidelines:
- **Low ensemble_weight (0.2-0.3)**:
- Emphasizes Long model (BiLSTM-CNN)
- Better for detecting complex patterns
- Higher sensitivity, may have more false positives
- **High ensemble_weight (0.6-0.8)**:
- Emphasizes Short model (HistGB)
- Faster computation
- Good for initial screening
- Higher specificity, may miss subtle sites
- **Balanced (0.4-0.5)**:
- Recommended for most applications
- Good balance of sensitivity and specificity
- Suitable for comprehensive analysis
## Output Field Reference
### Main Prediction Fields
- **`Short_Probability`**: HistGradientBoosting model prediction (0-1)
- **`Long_Probability`**: BiLSTM-CNN model prediction (0-1)
- **`Ensemble_Probability`**: Final ensemble prediction (primary result)
- **`Ensemble_Weights`**: Weight configuration information
### Sequence Fields
- **`Short_Sequence`**: 33bp sequence used by Short model
- **`Long_Sequence`**: 399bp sequence used by Long model
- **`Position`**: Position in the original sequence
- **`Codon`**: 3bp codon at the position
### Metadata Fields
- **`Sequence_ID`**: Identifier for multi-sequence predictions
- Additional fields from input DataFrame (for region predictions)
## Examples
See `example_plot_prediction.py` for comprehensive examples of:
- Basic prediction plotting
- Custom threshold configuration
- Ensemble weight parameter usage and comparison
- New API method demonstrations
- Saving plots to files
- Advanced visualization options
## Authors

362
example_plot_prediction.py Normal file
View File

@ -0,0 +1,362 @@
#!/usr/bin/env python3
"""
FScanpy 序列预测绘图示例
展示如何使用新的 plot_prf_prediction 函数绘制序列的移码概率预测结果
包含集成权重参数的使用示例
"""
import matplotlib.pyplot as plt
import os
from FScanpy import plot_prf_prediction, PRFPredictor
def example_basic_plotting():
    """Plot a demo sequence with the default settings.

    Returns:
        ``(results, fig)`` as returned by plot_prf_prediction(), or
        ``(None, None)`` if anything in the example raises.
    """
    banner = "=" * 50
    print(banner)
    print("基础绘图示例")
    print(banner)

    # Demo sequence (replace with a real sequence of interest).
    example_sequence = "".join([
        "ATGCGTACGTTAGCGATCGATCGTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGC",
        "GATCGATCGTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCGATCGATCGTAGCT",
        "AGCTAGCTAGCTAGCTAGCTAGCTAGCGATCGATCGTAGCTAGCTAGCTAGCTAG",
        "CTAGCTAGCTAGCGATCGATCGTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGC",
        "GATCGATCGTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCGATCGATCGTAGCT",
        "AGCTAGCTAGCTAGCTAGCTAGCTAGCGATCGATCGTAGCTAGCTAGCTAGCTAG",
        "CTAGCTAGCTAGCGATCGATCGTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGC",
        "GATCGATCGTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCGATCGATCGTAGCT",
        "AGCTAGCTAGCTAGCTAGCTAGCTAGCGATCGATCGTAGCTAGCTAGCTAGCTAG",
        "CTAGCTAGCTAGCGATCGATCGTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGC",
    ])

    try:
        # Only sequence and title are supplied; every other plotting option
        # is left at plot_prf_prediction()'s defaults.
        outcome = plot_prf_prediction(
            sequence=example_sequence,
            title="示例序列的移码概率预测 (默认集成权重 4:6)"
        )
        results, fig = outcome

        nonzero_rows = results[results['Ensemble_Probability'] > 0]
        print(f"预测完成!共处理 {len(results)} 个位置")
        print(f"满足阈值条件的位点数: {len(nonzero_rows)}")
        print("使用集成权重比例: Short模型 0.4, Long模型 0.6")

        # Show the figure.
        plt.show()
        return results, fig
    except Exception as err:
        print(f"绘图过程中出错: {err}")
        return None, None
def example_custom_ensemble_weights():
    """Plot the same demo sequence under several ensemble-weight settings.

    Each configuration is attempted independently; a failure in one is
    reported and does not stop the remaining ones.

    Returns:
        A list of ``(ensemble_weight, results, fig)`` tuples for the
        configurations that succeeded, or ``None`` when none succeeded, so
        that callers (e.g. ``main()``) can tell success from failure.
    """
    print("=" * 50)
    print("自定义集成权重绘图示例")
    print("=" * 50)

    # Demo sequence.
    example_sequence = (
        "ATGCGTACGTTAGCGATCGATCGTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGC"
        "GATCGATCGTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCGATCGATCGTAGCT"
        "AGCTAGCTAGCTAGCTAGCTAGCTAGCGATCGATCGTAGCTAGCTAGCTAGCTAG"
        "CTAGCTAGCTAGCGATCGATCGTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGC"
        "GATCGATCGTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCGATCGATCGTAGCT"
    )

    # (Short-model weight, human-readable description) pairs to try.
    weight_configs = [
        (0.2, "Long模型主导 (2:8)"),
        (0.5, "等权重组合 (5:5)"),
        (0.7, "Short模型主导 (7:3)")
    ]

    succeeded = []
    for ensemble_weight, description in weight_configs:
        print(f"\n测试集成权重配置: {description}")

        try:
            results, fig = plot_prf_prediction(
                sequence=example_sequence,
                ensemble_weight=ensemble_weight,
                title=f"移码概率预测 - {description}",
                figsize=(14, 7)
            )

            print(f"预测完成!共处理 {len(results)} 个位置")
            print(f"满足阈值条件的位点数: {len(results[results['Ensemble_Probability'] > 0])}")

            # Summary statistics for this configuration.
            print("预测统计信息:")
            print(f" Short模型平均概率: {results['Short_Probability'].mean():.3f}")
            print(f" Long模型平均概率: {results['Long_Probability'].mean():.3f}")
            print(f" 集成平均概率: {results['Ensemble_Probability'].mean():.3f}")
            print(f" 集成权重比例: Short:{ensemble_weight:.1f}, Long:{1-ensemble_weight:.1f}")

            plt.show()
            succeeded.append((ensemble_weight, results, fig))

        except Exception as e:
            print(f"集成权重 {ensemble_weight} 绘图时出错: {str(e)}")

    # Previously nothing was returned, so main() always reported this example
    # as failed; return the collected results (None if every config failed).
    return succeeded if succeeded else None
def example_ensemble_comparison():
    """Compare predict_sequence() output for three ensemble weights.

    Draws one bar chart per weight in a stacked three-row figure and prints
    per-weight summary statistics.

    Returns:
        ``(all_results, fig)`` where ``all_results`` is the list of result
        frames in weight order, or ``(None, None)`` if anything raises.
    """
    rule = "=" * 50
    print(rule)
    print("集成权重对比绘图示例")
    print(rule)

    # Demo sequence.
    example_sequence = "".join([
        "ATGCGTACGTTAGCGATCGATCGTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGC",
        "GATCGATCGTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCGATCGATCGTAGCT",
        "AGCTAGCTAGCTAGCTAGCTAGCTAGCGATCGATCGTAGCTAGCTAGCTAGCTAG",
        "CTAGCTAGCTAGCGATCGATCGTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGC",
    ])

    try:
        predictor = PRFPredictor()

        # Three weights to compare, with matching display names.
        weights = [0.3, 0.4, 0.6]
        weight_names = ["Long主导 (3:7)", "默认权重 (4:6)", "Short主导 (6:4)"]

        # One subplot row per weight.
        fig, axes = plt.subplots(3, 1, figsize=(15, 12))
        fig.suptitle('不同集成权重配置的预测结果对比', fontsize=16)

        all_results = []
        last_row = len(weights) - 1
        for row in range(len(weights)):
            label = weight_names[row]
            frame = predictor.predict_sequence(
                sequence=example_sequence,
                ensemble_weight=weights[row]
            )
            all_results.append(frame)

            # Bar chart of the ensemble probability along the sequence.
            panel = axes[row]
            panel.bar(frame['Position'], frame['Ensemble_Probability'],
                      alpha=0.7, color=f'C{row}', width=2)
            panel.set_title(f'{label} - 平均概率: {frame["Ensemble_Probability"].mean():.3f}')
            panel.set_ylabel('概率')
            panel.grid(True, alpha=0.3)
            panel.set_ylim(0, 1)

            # Only the bottom panel carries the x-axis label.
            if row == last_row:
                panel.set_xlabel('序列位置')

        plt.tight_layout()
        plt.show()

        # Per-weight summary statistics.
        print("\n集成权重对比统计:")
        for label, frame in zip(weight_names, all_results):
            ensemble = frame['Ensemble_Probability']
            print(f"{label}:")
            print(f" 平均集成概率: {ensemble.mean():.3f}")
            print(f" 最大集成概率: {ensemble.max():.3f}")
            print(f" 非零预测数量: {(ensemble > 0).sum()}")

        return all_results, fig

    except Exception as err:
        print(f"集成权重对比时出错: {err}")
        return None, None
def example_save_plot():
    """Render the demo sequence for three ensemble weights and save each plot.

    Files are written as ``prediction_<suffix>.png`` inside the
    ``prediction_plots`` directory, which is created if missing.

    Returns:
        ``True`` when every plot was produced, ``False`` if anything raised.
    """
    divider = "=" * 50
    print(divider)
    print("保存图片示例")
    print(divider)

    # Make sure the output directory exists.
    save_dir = "prediction_plots"
    os.makedirs(save_dir, exist_ok=True)

    # Demo sequence.
    example_sequence = "".join([
        "ATGCGTACGTTAGCGATCGATCGTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGC",
        "GATCGATCGTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCGATCGATCGTAGCT",
        "AGCTAGCTAGCTAGCTAGCTAGCTAGCGATCGATCGTAGCTAGCTAGCTAGCTAG",
        "CTAGCTAGCTAGCGATCGATCGTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGC",
    ])

    try:
        # file-name suffix -> Short-model weight, in the order they are saved.
        weight_by_suffix = {
            "long_dominant": 0.3,
            "equal_weight": 0.5,
            "short_dominant": 0.7,
        }

        for file_suffix, ensemble_weight in weight_by_suffix.items():
            save_path = os.path.join(save_dir, f"prediction_{file_suffix}.png")

            _, fig = plot_prf_prediction(
                sequence=example_sequence,
                short_threshold=0.6,
                long_threshold=0.75,
                ensemble_weight=ensemble_weight,
                title=f"移码概率预测 (集成权重 {ensemble_weight:.1f}:{1-ensemble_weight:.1f})",
                save_path=save_path,
                dpi=300
            )

            print(f"图片已保存至: {save_path}")

            # Close without displaying the figure.
            plt.close(fig)

        print("所有集成权重配置的图片都已保存完成")
        return True

    except Exception as err:
        print(f"保存图片过程中出错: {err}")
        return False
def example_direct_predictor_usage():
    """Plot via the PRFPredictor.plot_sequence_prediction() method directly.

    Returns:
        ``(results, fig)`` from plot_sequence_prediction(), or
        ``(None, None)`` if anything raises.
    """
    separator = "=" * 50
    print(separator)
    print("直接使用PRFPredictor类绘图示例")
    print(separator)

    try:
        predictor = PRFPredictor()

        # Demo sequence.
        example_sequence = "".join([
            "ATGCGTACGTTAGCGATCGATCGTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGC",
            "GATCGATCGTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCGATCGATCGTAGCT",
            "AGCTAGCTAGCTAGCTAGCTAGCTAGCGATCGATCGTAGCTAGCTAGCTAGCTAG",
        ])

        # Custom ensemble weight for the Short model; Long gets the remainder.
        short_share = 0.3
        results, fig = predictor.plot_sequence_prediction(
            sequence=example_sequence,
            short_threshold=0.65,
            long_threshold=0.8,
            ensemble_weight=short_share,
            title="使用PRFPredictor类的绘图结果 (集成权重 3:7)"
        )

        print(f"预测完成!共处理 {len(results)} 个位置")
        print(f"使用集成权重比例: Short:{short_share:.1f}, Long:{1 - short_share:.1f}")

        # First rows of the main probability columns.
        print("\n前10个预测结果:")
        preview = results[['Position', 'Short_Probability', 'Long_Probability', 'Ensemble_Probability']]
        print(preview.head(10))

        # Weight description, when the result frame carries it.
        if 'Ensemble_Weights' in results.columns:
            print(f"\n集成权重配置: {results['Ensemble_Weights'].iloc[0]}")

        plt.show()
        return results, fig

    except Exception as err:
        print(f"使用PRFPredictor类时出错: {err}")
        return None, None
def example_new_api_usage():
    """Exercise predict_sequence() and predict_regions() on a demo sequence.

    Returns:
        ``(results, region_results)`` from the two calls, or
        ``(None, None)`` if anything raises.
    """
    line = "=" * 50
    print(line)
    print("新API方法使用示例")
    print(line)

    try:
        predictor = PRFPredictor()

        # Demo sequence.
        example_sequence = "".join([
            "ATGCGTACGTTAGCGATCGATCGTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGC",
            "GATCGATCGTAGCTAGCTAGCTAGCTAGCTAGCTAGCTAGCGATCGATCGTAGCT",
        ])

        print("1. 使用新的 predict_sequence() 方法:")
        results = predictor.predict_sequence(
            sequence=example_sequence,
            ensemble_weight=0.3
        )
        probability_columns = [col for col in results.columns if 'Probability' in col]
        print(f" 序列预测完成: {len(results)} 个位置")
        print(f" 主要输出字段: {probability_columns}")

        print("\n2. 使用新的 predict_regions() 方法:")
        # Build a mock 399-character region by right-padding the demo
        # sequence with "A".
        padded_region = example_sequence.ljust(399, "A")
        region_results = predictor.predict_regions(
            sequences=[padded_region],
            ensemble_weight=0.4
        )
        region_columns = [
            col for col in region_results.columns
            if 'Probability' in col or 'Sequence' in col
        ]
        print(f" 区域预测完成: {len(region_results)} 个序列")
        print(f" 主要输出字段: {region_columns}")

        # Summary statistics.
        print("\n3. 结果统计:")
        print(f" 序列预测平均集成概率: {results['Ensemble_Probability'].mean():.3f}")
        print(f" 区域预测平均集成概率: {region_results['Ensemble_Probability'].mean():.3f}")

        return results, region_results

    except Exception as err:
        print(f"新API使用时出错: {err}")
        return None, None
def _example_succeeded(result):
    """Return True when an example's return value signals success.

    Examples in this script signal failure by returning ``None``, ``False``
    or a tuple made only of ``None`` values (e.g. ``(None, None)``).
    """
    if result is None or result is False:
        return False
    if isinstance(result, tuple) and all(item is None for item in result):
        return False
    return True


def main():
    """Run every example in order and report per-example success or failure."""
    print("FScanpy 序列预测绘图功能演示")
    print("=" * 60)
    print("新功能:规范化的集成权重参数 (ensemble_weight)")
    print("权重范围0.0 到 1.0 (对应 Short模型的权重Long模型权重 = 1 - ensemble_weight)")
    print("新命名Ensemble_Probability 替代 Voting_Probability")
    print("=" * 60)

    examples = [
        ("1. 基础绘图示例", example_basic_plotting),
        ("2. 自定义集成权重示例", example_custom_ensemble_weights),
        ("3. 集成权重对比示例", example_ensemble_comparison),
        ("4. 保存图片示例", example_save_plot),
        ("5. 直接使用PRFPredictor类示例", example_direct_predictor_usage),
        ("6. 新API方法使用示例", example_new_api_usage)
    ]

    for name, func in examples:
        print(f"\n{name}")
        try:
            result = func()
            # The previous check (`result is not None and result != False`)
            # treated the `(None, None)` failure return as success.
            if _example_succeeded(result):
                print("✓ 示例执行成功")
            else:
                print("✗ 示例执行失败")
        except Exception as e:
            print(f"✗ 示例执行出错: {str(e)}")
        print("-" * 50)

    print("\n演示完成!")
    print("\n📊 新功能总结:")
    print("1. plot_prf_prediction(): 便捷的绘图函数")
    print("2. PRFPredictor.plot_sequence_prediction(): 类方法绘图")
    print("3. PRFPredictor.predict_sequence(): 序列滑动窗口预测替代predict_full")
    print("4. PRFPredictor.predict_regions(): 区域批量预测替代predict_region")
    print("5. 支持自定义阈值、标题、保存路径等参数")
    print("6. 新增 ensemble_weight 参数,可调节两个模型的集成权重比例")
    # Item 7 belongs to the numbered summary; it used to be printed after the
    # output-field section below.
    print("7. 自动保存PNG和PDF两种格式")
    print("\n⚖️ 集成权重示例:")
    print(" - ensemble_weight=0.2: Short模型20%, Long模型80% (Long主导)")
    print(" - ensemble_weight=0.4: Short模型40%, Long模型60% (默认平衡)")
    print(" - ensemble_weight=0.5: Short模型50%, Long模型50% (等权重)")
    print(" - ensemble_weight=0.7: Short模型70%, Long模型30% (Short主导)")
    print("\n📂 输出字段:")
    print(" - Short_Probability: Short模型(HistGB)预测概率")
    print(" - Long_Probability: Long模型(BiLSTM-CNN)预测概率")
    print(" - Ensemble_Probability: 集成预测概率(主要结果)")
    print(" - Ensemble_Weights: 权重配置信息")
    print(" - Short_Sequence: 33bp序列")
    print(" - Long_Sequence: 399bp序列")


if __name__ == "__main__":
    main()

View File

@ -7,7 +7,7 @@ FScanpy is a Python package designed to predict Programmed Ribosomal Frameshifti
FScanpy is a Python package dedicated to predicting Programmed Ribosomal Frameshifting (PRF) sites in DNA sequences. It integrates machine learning models (Gradient Boosting and BiLSTM-CNN) along with the FScanR package to furnish precise PRF predictions. Users are capable of employing three types of data as input: the entire cDNA/mRNA sequence that requires prediction, the nucleotide sequence in the vicinity of the suspected frameshift site, and the peptide library blastx results of the species or related species. It anticipates the input sequence to be in the + strand and can be integrated with FScanR to augment the accuracy.
![Machine learning models](/image/ML.png)
For the prediction of the entire sequence, FScanpy adopts a sliding window approach to scan the entire sequence and predict the PRF sites. For regional prediction, it is based on the 33-bp and 399-bp sequences in the 0 reading frame around the suspected frameshift site. Initially, the Gradient Boosting model will predict the potential PRF sites within the scanning window. If the predicted probability exceeds the threshold, the BiLSTM-CNN model will predict the PRF sites in the 399bp sequence.Then,VotingClassifier will combine the two models to make the final prediction.
For the prediction of the entire sequence, FScanpy adopts a sliding window approach to scan the entire sequence and predict the PRF sites. For regional prediction, it is based on the 33-bp and 399-bp sequences in the 0 reading frame around the suspected frameshift site. Initially, the Short model (HistGradientBoosting) will predict the potential PRF sites within the scanning window. If the predicted probability exceeds the threshold, the Long model (BiLSTM-CNN) will predict the PRF sites in the 399bp sequence. Then, ensemble weighting combines the two models to make the final prediction.
For PRF detection from BLASTX output, [FScanR](https://github.com/seanchen607/FScanR.git) identifies potential PRF sites from BLASTX alignment results, acquires the two hits of the same query sequence, and then utilizes frameDist_cutoff, mismatch_cutoff, and evalue_cutoff to filter the hits. Finally, FScanpy is utilized to predict the probability of PRF sites.
@ -17,8 +17,8 @@ For PRF detection from BLASTX output, [FScanR](https://github.com/seanchen607/FS
### Key features of FScanpy include:
- Integration of two predictive models:
- [Gradient Boosting](https://tensorflow.google.cn/tutorials/estimator/boosted_trees?hl=en): Analyzes local sequence features centered around potential frameshift sites (10 codons).
- [BiLSTM-CNN](https://paperswithcode.com/method/cnn-bilstm): Analyzes broader sequence features (100 codons).
- Short Model (HistGradientBoosting): Analyzes local sequence features centered around potential frameshift sites (33bp).
- Long Model (BiLSTM-CNN): Analyzes broader sequence features (399bp).
- Supports PRF prediction across various species.
- Can be combined with [FScanR](https://github.com/seanchen607/FScanR.git) for enhanced accuracy.
@ -29,7 +29,7 @@ For PRF detection from BLASTX output, [FScanR](https://github.com/seanchen607/FS
pip install FScanpy
```
### 2. Clone from [GitHub](https://github.com/.../FScanpy.git)
### 2. Clone from GitHub
```bash
git clone https://github.com/.../FScanpy.git
cd your_project_directory
@ -39,7 +39,6 @@ pip install -e .
## Methods and Usage
### 1. Load model and test data
Test data can be found in `FScanpy/data/test_data`. Use the `list_test_data()` method to list all available test data and the `get_test_data_path()` method to get the path of a specific test file:
```python
from FScanpy import PRFPredictor
from FScanpy.data import get_test_data_path, list_test_data
@ -51,56 +50,38 @@ region_example = get_test_data_path('region_example.xlsx')
```
### 2. Predict PRF Sites in a Full Sequence
Use the `predict_full()` method to scan the entire sequence,you can use the `window_size` parameter to adjust the scanning window size(default is 3) and the `gb_threshold` parameter to adjust the Gradient Boosting model fitting threshold(default is 0.1) for faster or more accurate prediction:
Use the `predict_sequence()` method to scan the entire sequence:
```python
'''
Args:
sequence: mRNA sequence
window_size: scanning window size (default is 3)
gb_threshold: Gradient Boosting model threshold (default is 0.1)
Returns:
results: DataFrame containing prediction probabilities
'''
results = predictor.predict_full(sequence='ATGCGTACGTATGCGTACGTATGCGTACGT',
window_size=3, # Scanning window size
gb_threshold=0.1, # Gradient Boosting model threshold
plot=True) # Whether to plot the prediction results
fig.savefig('predict_full.png')
results = predictor.predict_sequence(
sequence='ATGCGTACGTATGCGTACGTATGCGTACGT',
window_size=3, # Scanning window size
short_threshold=0.1, # Short model threshold
ensemble_weight=0.4 # Ensemble weight (Short:Long = 0.4:0.6)
)
# With visualization
results, fig = predictor.plot_sequence_prediction(
sequence='ATGCGTACGTATGCGTACGTATGCGTACGT',
ensemble_weight=0.4
)
```
### 3. Predict PRF in Specific Regions
Use the `predict_region()` method to predict PRF in known regions of interest:
Use the `predict_regions()` method to predict PRF in known regions of interest:
```python
'''
Args:
seq: 399bp sequence
gb_threshold: GB model probability threshold (default is 0.1)
Returns:
DataFrame: 包含所有序列预测概率的DataFrame
'''
import pandas as pd
region_example = pd.read_excel(get_test_data_path('region_example.xlsx'))
results = predictor.predict_region(seq=region_example['399bp'])
results = predictor.predict_regions(
sequences=region_example['399bp'],
ensemble_weight=0.4
)
```
### 4. Identify PRF Sites from BLASTX Output
BLASTX Output should contain the following columns: `qseqid`, `sseqid`, `pident`, `length`, `mismatch`, `gapopen`, `qstart`, `qend`, `sstart`, `send`, `evalue`, `bitscore`, `qframe`, `sframe`.
The FScanR result contains the `DNA_seqid`, `FS_start`, `FS_end`, `FS_type`, `Pep_seqid`, `Pep_FS_start`, `Pep_FS_end`, and `Strand` columns.
Use the FScanR function to identify potential PRF sites from BLASTX alignment results:
```python
"""
identify PRF sites from BLASTX output
Args:
blastx_output: BLASTX output DataFrame
mismatch_cutoff: mismatch threshold
evalue_cutoff: E-value threshold
frameDist_cutoff: frame distance threshold
Returns:
pd.DataFrame: DataFrame containing PRF site information
"""
from FScanpy.utils import fscanr
blastx_output = pd.read_excel(get_test_data_path('blastx_example.xlsx'))
fscanr_result = fscanr(blastx_output,
@ -109,56 +90,43 @@ fscanr_result = fscanr(blastx_output,
frameDist_cutoff=10) # Frame distance threshold
```
### 5. Extract PRF Sites from BLASTX Output or Your Sequence Data and Evaluate Them with FScanpy
Use the `extract_prf_regions()` method to extract PRF site sequences from mRNA sequences. It uses the `FS_start` column of the FScanR output, together with the `DNA_seqid` column matched against the input mRNA sequence file, to extract the 33bp and 399bp sequences around the PRF sites in the 0 reading frame:
### 5. Extract PRF Sites and Evaluate
Use the `extract_prf_regions()` method to extract PRF site sequences from mRNA sequences:
```python
"""
extract PRF site sequences from mRNA sequences
Args:
mrna_file: mRNA sequence file path (FASTA format)
prf_data: FScanR output PRF site data, or your own suspected PRF site data, which must contain at least the `DNA_seqid`, `FS_start`, and `strand` columns
Returns:
pd.DataFrame: DataFrame containing 33bp and 399bp sequences
"""
from FScanpy.utils import extract_prf_regions
prf_regions = extract_prf_regions(mrna_file=get_test_data_path('mrna_example.fasta'),
prf_data=fscanr_result)
prf_results = predictor.predict_region (prf_regions['399bp'])
prf_regions = extract_prf_regions(
mrna_file=get_test_data_path('mrna_example.fasta'),
prf_data=fscanr_result
)
prf_results = predictor.predict_regions(prf_regions['399bp'])
```
## Total Test
## Complete Workflow Example
```python
from FScanpy import PRFPredictor
from FScanpy import PRFPredictor, predict_prf, plot_prf_prediction
from FScanpy.data import get_test_data_path, list_test_data
predictor = PRFPredictor() # load model
list_test_data() # list all the test data
blastx_file = get_test_data_path('blastx_example.xlsx')
mrna_file = get_test_data_path('mrna_example.fasta')
region_example = get_test_data_path('region_example.xlsx')
results = predictor.predict_full(sequence='ATGCGTACGTATGCGTACGTATGCGTACGT',
window_size=3, # Scanning window size
gb_threshold=0.1, # Gradient Boosting model threshold
plot=True)
from FScanpy.utils import fscanr, extract_prf_regions
import pandas as pd
region_example = pd.read_excel(get_test_data_path('region_example.xlsx'))
results = predictor.predict_region(seq=region_example['399bp'])
from FScanpy.utils import fscanr
# Initialize predictor
predictor = PRFPredictor()
# Method 1: Sequence prediction
sequence = 'ATGCGTACGTATGCGTACGTATGCGTACGT'
results = predict_prf(sequence=sequence, ensemble_weight=0.4)
# Method 2: Region prediction
region_data = pd.read_excel(get_test_data_path('region_example.xlsx'))
results = predict_prf(data=region_data, ensemble_weight=0.4)
# Method 3: BLASTX pipeline
blastx_output = pd.read_excel(get_test_data_path('blastx_example.xlsx'))
fscanr_result = fscanr(blastx_output,
mismatch_cutoff=10, # Allowed mismatches
evalue_cutoff=1e-5, # E-value threshold
frameDist_cutoff=10)
from FScanpy.utils import extract_prf_regions
prf_regions = extract_prf_regions(mrna_file=get_test_data_path('mrna_example.fasta'),
prf_data=fscanr_result)
prf_results = predictor.predict_region (prf_regions['399bp'])
fscanr_result = fscanr(blastx_output, mismatch_cutoff=10, evalue_cutoff=1e-5, frameDist_cutoff=10)
prf_regions = extract_prf_regions(get_test_data_path('mrna_example.fasta'), fscanr_result)
prf_results = predictor.predict_regions(prf_regions['399bp'])
# Visualization
results, fig = plot_prf_prediction(sequence, ensemble_weight=0.4, save_path='prediction.png')
```
## Citation