2025-03-18 11:21:54 +08:00
|
|
|
|
from .predictor import PRFPredictor
|
2025-05-29 17:58:48 +08:00
|
|
|
|
from . import data
|
2025-03-18 11:21:54 +08:00
|
|
|
|
import pandas as pd
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
from typing import Union, List, Dict
|
|
|
|
|
|
|
|
|
|
|
|
# Package metadata. Author and e-mail are intentionally left as empty strings
# here, exactly as in the original module.
__version__ = "0.3.0"
__author__ = ""
__email__ = ""

# Public API: the names exported by a star-import of this package.
__all__ = [
    "PRFPredictor",
    "predict_prf",
    "plot_prf_prediction",
    "data",
    "__version__",
    "__author__",
    "__email__",
]
|
2025-03-18 11:21:54 +08:00
|
|
|
|
|
|
|
|
|
|
def predict_prf(
    sequence: Union[str, List[str], None] = None,
    data: Union[pd.DataFrame, None] = None,
    window_size: int = 3,
    short_threshold: float = 0.1,
    ensemble_weight: float = 0.4,
    model_dir: Union[str, None] = None,
) -> pd.DataFrame:
    """Predict PRF sites for raw sequences or for a DataFrame of regions.

    Exactly one of ``sequence`` and ``data`` must be given. ``sequence``
    selects sliding-window prediction; ``data`` selects region prediction.

    Note: inside this function the ``data`` parameter shadows the package's
    ``data`` submodule imported at module level.

    Args:
        sequence: A single DNA sequence (``str``) or a list/tuple of
            sequences, used for sliding-window prediction.
        data: A DataFrame that must contain a ``'Long_Sequence'`` or
            ``'399bp'`` column, used for region prediction.
            ``'Long_Sequence'`` takes precedence when both are present.
        window_size: Sliding-window size (default 3). Only used in
            sliding-window mode.
        short_threshold: Probability threshold of the Short model (HistGB)
            (default 0.1).
        ensemble_weight: Weight of the Short model in the ensemble; the Long
            model gets ``1 - ensemble_weight`` (default 0.4, i.e. Long 0.6).
        model_dir: Optional path to the directory holding the model files.

    Returns:
        pandas.DataFrame: Prediction results. The main fields are:

        - ``Short_Probability``: probability predicted by the Short model
        - ``Long_Probability``: probability predicted by the Long model
        - ``Ensemble_Probability``: ensemble probability (main result)
        - ``Ensemble_Weights``: description of the weights used

        For a list/tuple of sequences, each per-sequence result gets a
        ``Sequence_ID`` column (``seq_1``, ``seq_2``, ...) and the results
        are concatenated; sequences whose prediction fails are reported with
        a printed warning and skipped, and an empty DataFrame is returned if
        none succeed.

        In region mode, the non-sequence columns of ``data`` are copied onto
        the result. If region prediction fails, a warning is printed and a
        placeholder DataFrame with all probabilities set to ``0.0`` plus
        every column of ``data`` is returned.

    Raises:
        ValueError: If neither or both of ``sequence`` and ``data`` are
            given, if ``ensemble_weight`` is outside ``[0.0, 1.0]``, if
            ``sequence`` is not a str/list/tuple, if ``data`` is not a
            DataFrame, or if ``data`` lacks both sequence columns.

    Examples:
        # 1. Sliding-window prediction for a single sequence
        >>> from FScanpy import predict_prf
        >>> sequence = "ATGCGTACGT..."
        >>> results = predict_prf(sequence=sequence)

        # 2. Sliding-window prediction for several sequences
        >>> sequences = ["ATGCGTACGT...", "GCTATAGCAT..."]
        >>> results = predict_prf(sequence=sequences)

        # 3. Custom ensemble weight (3:7 ratio)
        >>> results = predict_prf(sequence=sequence, ensemble_weight=0.3)

        # 4. Region prediction from a DataFrame
        >>> import pandas as pd
        >>> data = pd.DataFrame({
        ...     'Long_Sequence': ['ATGCGT...', 'GCTATAG...']  # or use '399bp'
        ... })
        >>> results = predict_prf(data=data)
    """
    # Validate every argument first, so that bad input is rejected before a
    # PRFPredictor is constructed.
    if sequence is None and data is None:
        raise ValueError("必须提供sequence或data参数之一")
    if sequence is not None and data is not None:
        raise ValueError("sequence和data参数不能同时提供")
    if not (0.0 <= ensemble_weight <= 1.0):
        raise ValueError("ensemble_weight 必须在 0.0 到 1.0 之间")

    seq_column = None
    if sequence is not None:
        # Previously an unsupported type fell through every branch and the
        # function silently returned None; reject it explicitly instead.
        # ValueError is used to match the type check on `data` below.
        if not isinstance(sequence, (str, list, tuple)):
            raise ValueError("sequence参数必须是字符串或字符串列表/元组类型")
    else:
        if not isinstance(data, pd.DataFrame):
            raise ValueError("data参数必须是pandas DataFrame类型")
        # Accept both the new and the legacy column name.
        if 'Long_Sequence' in data.columns:
            seq_column = 'Long_Sequence'
        elif '399bp' in data.columns:
            seq_column = '399bp'
        else:
            raise ValueError("DataFrame必须包含'Long_Sequence'或'399bp'列")

    predictor = PRFPredictor(model_dir=model_dir)

    # ---- Sliding-window mode -------------------------------------------
    if sequence is not None:
        if isinstance(sequence, str):
            return predictor.predict_sequence(
                sequence, window_size, short_threshold, ensemble_weight)

        results = []
        for i, seq in enumerate(sequence, 1):
            # Best-effort: a failure on one sequence is reported and skipped
            # so the remaining sequences are still predicted.
            try:
                result = predictor.predict_sequence(
                    seq, window_size, short_threshold, ensemble_weight)
                result['Sequence_ID'] = f'seq_{i}'
                results.append(result)
            except Exception as e:
                print(f"警告:序列 {i} 预测失败 - {str(e)}")
        return pd.concat(results, ignore_index=True) if results else pd.DataFrame()

    # ---- Region mode ---------------------------------------------------
    try:
        results = predictor.predict_regions(
            data[seq_column], short_threshold, ensemble_weight)

        # Carry over the caller's extra columns, leaving out the sequence
        # columns themselves (new and legacy names).
        for col in data.columns:
            if col not in ('Long_Sequence', '399bp', 'Short_Sequence', '33bp'):
                results[col] = data[col].values

        return results

    except Exception as e:
        print(f"警告:区域预测失败 - {str(e)}")

        # Build a placeholder result with one all-zero row per input row.
        n_rows = len(data)
        long_weight = 1.0 - ensemble_weight
        results = pd.DataFrame({
            'Short_Probability': [0.0] * n_rows,
            'Long_Probability': [0.0] * n_rows,
            'Ensemble_Probability': [0.0] * n_rows,
            'Ensemble_Weights': [f'Short:{ensemble_weight:.1f}, Long:{long_weight:.1f}'] * n_rows,
        })

        # The fallback copies every original column, sequence columns included.
        for col in data.columns:
            results[col] = data[col].values

        return results
|
|
|
|
|
|
|
|
|
|
|
|
def plot_prf_prediction(
    sequence: str,
    window_size: int = 3,
    short_threshold: float = 0.65,
    long_threshold: float = 0.8,
    ensemble_weight: float = 0.4,
    title: str = None,
    save_path: str = None,
    figsize: tuple = (12, 6),
    dpi: int = 300,
    model_dir: str = None
) -> tuple:
    """Plot the frameshift-probability profile predicted for a sequence.

    Validates ``ensemble_weight``, builds a ``PRFPredictor`` for
    ``model_dir`` and forwards all remaining arguments to its
    ``plot_sequence_prediction`` method.

    Args:
        sequence: Input DNA sequence.
        window_size: Sliding-window size (default 3).
        short_threshold: Filter threshold of the Short model (HistGB)
            (default 0.65).
        long_threshold: Filter threshold of the Long model (BiLSTM-CNN)
            (default 0.8).
        ensemble_weight: Weight of the Short model in the ensemble
            (default 0.4, leaving 0.6 for the Long model).
        title: Optional figure title.
        save_path: Optional output path; the figure is saved when given.
        figsize: Figure size (default ``(12, 6)``).
        dpi: Figure resolution (default 300).
        model_dir: Optional path to the directory holding the model files.

    Returns:
        tuple: The value returned by
        ``PRFPredictor.plot_sequence_prediction``, documented as
        ``(pd.DataFrame, matplotlib.figure.Figure)`` - the prediction
        results and the figure object.

    Raises:
        ValueError: If ``ensemble_weight`` is not within ``[0.0, 1.0]``.

    Examples:
        # 1. Basic plot
        >>> import matplotlib.pyplot as plt
        >>> from FScanpy import plot_prf_prediction
        >>> sequence = "ATGCGTACGT..."
        >>> results, fig = plot_prf_prediction(sequence)
        >>> plt.show()

        # 2. Custom thresholds and ensemble weight (3:7 ratio)
        >>> results, fig = plot_prf_prediction(
        ...     sequence,
        ...     short_threshold=0.7,
        ...     long_threshold=0.85,
        ...     ensemble_weight=0.3,
        ...     title="Prediction with custom weights",
        ...     save_path="prediction_result.png"
        ... )

        # 3. Equal weights (5:5)
        >>> results, fig = plot_prf_prediction(
        ...     sequence,
        ...     ensemble_weight=0.5
        ... )

        # 4. Long model dominant (2:8)
        >>> results, fig = plot_prf_prediction(
        ...     sequence,
        ...     ensemble_weight=0.2
        ... )
    """
    # Chained comparison kept on purpose: it is False for NaN, so NaN is
    # rejected together with out-of-range values.
    weight_in_range = 0.0 <= ensemble_weight <= 1.0
    if not weight_in_range:
        raise ValueError("ensemble_weight 必须在 0.0 到 1.0 之间")

    plot_options = {
        "sequence": sequence,
        "window_size": window_size,
        "short_threshold": short_threshold,
        "long_threshold": long_threshold,
        "ensemble_weight": ensemble_weight,
        "title": title,
        "save_path": save_path,
        "figsize": figsize,
        "dpi": dpi,
    }

    prf_model = PRFPredictor(model_dir=model_dir)
    return prf_model.plot_sequence_prediction(**plot_options)
|