205 lines
7.7 KiB
Python
205 lines
7.7 KiB
Python
from .predictor import PRFPredictor
|
||
from . import data
|
||
import pandas as pd
|
||
import numpy as np
|
||
from typing import Union, List, Dict
|
||
|
||
# Package metadata.
__version__ = '0.3.0'
__author__ = ''
__email__ = ''

# Public names exported by ``from FScanpy import *``.
__all__ = [
    'PRFPredictor',
    'predict_prf',
    'plot_prf_prediction',
    'data',
    '__version__',
    '__author__',
    '__email__',
]
|
||
|
||
def predict_prf(
    sequence: Union[str, List[str], None] = None,
    data: Union[pd.DataFrame, None] = None,
    window_size: int = 3,
    short_threshold: float = 0.1,
    ensemble_weight: float = 0.4,
    model_dir: Union[str, None] = None
) -> pd.DataFrame:
    """
    Predict PRF (programmed ribosomal frameshift) sites.

    Exactly one of ``sequence`` or ``data`` must be given.

    Args:
        sequence: One DNA sequence (str) or several sequences (list, tuple,
            pandas Series or numpy array of str), scanned with a sliding window.
        data: DataFrame for region prediction; must contain a 'Long_Sequence'
            or '399bp' column ('Long_Sequence' wins when both are present).
        window_size: Sliding-window size (default 3).
        short_threshold: Probability threshold of the Short model (HistGB)
            (default 0.1).
        ensemble_weight: Weight of the Short model in the ensemble, must lie
            in [0.0, 1.0] (default 0.4; the Long model gets 1 - ensemble_weight).
        model_dir: Optional directory containing the model files.

    Returns:
        pandas.DataFrame whose main columns are:
            - Short_Probability: Short model probability
            - Long_Probability: Long model probability
            - Ensemble_Probability: ensemble probability (main result)
            - Ensemble_Weights: weight configuration string
        With several sequences, a 'Sequence_ID' column ('seq_1', 'seq_2', ...)
        is added; sequences whose prediction raises are reported on stdout and
        skipped (an empty DataFrame is returned if all fail). In region mode the
        non-sequence columns of ``data`` are copied into the result; if region
        prediction raises, a warning is printed and an all-zero result carrying
        every column of ``data`` is returned instead.

    Raises:
        ValueError: if neither or both of ``sequence``/``data`` are given, if
            ``ensemble_weight`` is outside [0.0, 1.0], if ``data`` is not a
            DataFrame, or if ``data`` has no 'Long_Sequence'/'399bp' column.
        TypeError: if ``sequence`` is not a str or a supported collection.

    Examples:
        # 1. Sliding-window prediction for a single sequence
        >>> from FScanpy import predict_prf
        >>> sequence = "ATGCGTACGT..."
        >>> results = predict_prf(sequence=sequence)

        # 2. Sliding-window prediction for several sequences
        >>> sequences = ["ATGCGTACGT...", "GCTATAGCAT..."]
        >>> results = predict_prf(sequence=sequences)

        # 3. Custom ensemble weights
        >>> results = predict_prf(sequence=sequence, ensemble_weight=0.3)  # 3:7 ratio

        # 4. Region prediction from a DataFrame
        >>> import pandas as pd
        >>> data = pd.DataFrame({
        ...     'Long_Sequence': ['ATGCGT...', 'GCTATAG...']  # or use '399bp'
        ... })
        >>> results = predict_prf(data=data)
    """
    # Validate everything before constructing PRFPredictor: building the
    # predictor may be expensive (presumably it loads the model files from
    # model_dir), so invalid input should fail fast with a clear error.
    if sequence is None and data is None:
        raise ValueError("必须提供sequence或data参数之一")
    if sequence is not None and data is not None:
        raise ValueError("sequence和data参数不能同时提供")
    if not (0.0 <= ensemble_weight <= 1.0):
        raise ValueError("ensemble_weight 必须在 0.0 到 1.0 之间")

    # Sliding-window mode.
    if sequence is not None:
        if isinstance(sequence, str):
            predictor = PRFPredictor(model_dir=model_dir)
            return predictor.predict_sequence(
                sequence, window_size, short_threshold, ensemble_weight)
        # Reject unsupported types explicitly instead of silently returning None.
        if not isinstance(sequence, _MULTI_SEQUENCE_TYPES):
            raise TypeError("sequence参数必须是字符串或字符串列表")
        predictor = PRFPredictor(model_dir=model_dir)
        return _predict_many_sequences(
            predictor, list(sequence), window_size, short_threshold, ensemble_weight)

    # Region mode.
    if not isinstance(data, pd.DataFrame):
        raise ValueError("data参数必须是pandas DataFrame类型")
    seq_column = _find_sequence_column(data)
    predictor = PRFPredictor(model_dir=model_dir)
    return _predict_regions_with_fallback(
        predictor, data, seq_column, short_threshold, ensemble_weight)


# Collection types accepted for multi-sequence prediction.
_MULTI_SEQUENCE_TYPES = (list, tuple, pd.Series, np.ndarray)

# Long-sequence column names accepted in region mode, in order of preference
# (new name first, legacy name second).
_LONG_SEQUENCE_COLUMNS = ('Long_Sequence', '399bp')

# Sequence columns that are NOT copied from the input DataFrame into a
# successful region-prediction result.
_SEQUENCE_COLUMNS = ('Long_Sequence', '399bp', 'Short_Sequence', '33bp')


def _find_sequence_column(data: pd.DataFrame) -> str:
    """Return the long-sequence column of *data*, or raise ValueError if absent."""
    for column in _LONG_SEQUENCE_COLUMNS:
        if column in data.columns:
            return column
    raise ValueError("DataFrame必须包含'Long_Sequence'或'399bp'列")


def _predict_many_sequences(
    predictor: "PRFPredictor",
    sequences: List[str],
    window_size: int,
    short_threshold: float,
    ensemble_weight: float
) -> pd.DataFrame:
    """Run sliding-window prediction on each sequence and stack the results."""
    results = []
    for i, seq in enumerate(sequences, 1):
        try:
            result = predictor.predict_sequence(
                seq, window_size, short_threshold, ensemble_weight)
            result['Sequence_ID'] = f'seq_{i}'
            results.append(result)
        except Exception as e:
            # Best effort: one failing sequence must not abort the whole batch.
            print(f"警告:序列 {i} 预测失败 - {str(e)}")
    return pd.concat(results, ignore_index=True) if results else pd.DataFrame()


def _predict_regions_with_fallback(
    predictor: "PRFPredictor",
    data: pd.DataFrame,
    seq_column: str,
    short_threshold: float,
    ensemble_weight: float
) -> pd.DataFrame:
    """Region prediction on data[seq_column]; an all-zero frame is returned on failure."""
    try:
        results = predictor.predict_regions(
            data[seq_column], short_threshold, ensemble_weight)

        # Carry over the caller's extra columns; .values avoids index alignment.
        for col in data.columns:
            if col not in _SEQUENCE_COLUMNS:
                results[col] = data[col].values

        return results

    except Exception as e:
        print(f"警告:区域预测失败 - {str(e)}")
        # Build an all-zero placeholder result with one row per input row.
        long_weight = 1.0 - ensemble_weight
        n_rows = len(data)
        results = pd.DataFrame({
            'Short_Probability': [0.0] * n_rows,
            'Long_Probability': [0.0] * n_rows,
            'Ensemble_Probability': [0.0] * n_rows,
            'Ensemble_Weights': [f'Short:{ensemble_weight:.1f}, Long:{long_weight:.1f}'] * n_rows
        })

        # The fallback keeps every original column, sequence columns included.
        for col in data.columns:
            results[col] = data[col].values

        return results
|
||
|
||
def plot_prf_prediction(
    sequence: str,
    window_size: int = 3,
    short_threshold: float = 0.65,
    long_threshold: float = 0.8,
    ensemble_weight: float = 0.4,
    title: str = None,
    save_path: str = None,
    figsize: tuple = (12, 6),
    dpi: int = 300,
    model_dir: str = None
) -> tuple:
    """
    Plot the frameshift probability profile predicted for one DNA sequence.

    The weight is validated first. Then a PRFPredictor is built from
    ``model_dir``, and plotting is delegated to its ``plot_sequence_prediction``.

    Args:
        sequence: Input DNA sequence.
        window_size: Sliding-window size (default 3).
        short_threshold: Filtering threshold of the Short model (HistGB)
            (default 0.65).
        long_threshold: Filtering threshold of the Long model (BiLSTM-CNN)
            (default 0.8).
        ensemble_weight: Weight of the Short model in the ensemble, must lie
            in [0.0, 1.0] (default 0.4; the Long model weight is 0.6).
        title: Optional figure title.
        save_path: Optional path; when given, the figure is saved there.
        figsize: Figure size (default (12, 6)).
        dpi: Figure resolution (default 300).
        model_dir: Optional directory containing the model files.

    Returns:
        tuple: (pd.DataFrame, matplotlib.figure.Figure), i.e. the prediction
        results and the figure object.

    Raises:
        ValueError: if ``ensemble_weight`` is outside [0.0, 1.0] (or NaN).

    Examples:
        # 1. Simple plot
        >>> import matplotlib.pyplot as plt
        >>> from FScanpy import plot_prf_prediction
        >>> sequence = "ATGCGTACGT..."
        >>> results, fig = plot_prf_prediction(sequence)
        >>> plt.show()

        # 2. Custom thresholds and ensemble weight
        >>> results, fig = plot_prf_prediction(
        ...     sequence,
        ...     short_threshold=0.7,
        ...     long_threshold=0.85,
        ...     ensemble_weight=0.3,  # 3:7 weight ratio
        ...     title="Prediction with custom weights",
        ...     save_path="prediction_result.png"
        ... )

        # 3. Equal weights
        >>> results, fig = plot_prf_prediction(
        ...     sequence,
        ...     ensemble_weight=0.5  # 5:5 equal weights
        ... )

        # 4. Long model dominant
        >>> results, fig = plot_prf_prediction(
        ...     sequence,
        ...     ensemble_weight=0.2  # 2:8 weights, Long model dominant
        ... )
    """
    # The chained comparison is False for NaN, so NaN is rejected as well.
    weight_in_range = 0.0 <= ensemble_weight <= 1.0
    if not weight_in_range:
        raise ValueError("ensemble_weight 必须在 0.0 到 1.0 之间")

    plot_options = dict(
        sequence=sequence,
        window_size=window_size,
        short_threshold=short_threshold,
        long_threshold=long_threshold,
        ensemble_weight=ensemble_weight,
        title=title,
        save_path=save_path,
        figsize=figsize,
        dpi=dpi,
    )
    return PRFPredictor(model_dir=model_dir).plot_sequence_prediction(**plot_options)