FScanpy-package/FScanpy/__init__.py

205 lines
7.7 KiB
Python
Raw Normal View History

2025-03-18 11:21:54 +08:00
from .predictor import PRFPredictor
2025-05-29 17:58:48 +08:00
from . import data
2025-03-18 11:21:54 +08:00
import pandas as pd
import numpy as np
from typing import Union, List, Dict
# Package metadata.
__version__ = '0.3.0'
__author__ = ''
__email__ = ''

# Names exported by ``from FScanpy import *``: the public functions/classes,
# the ``data`` submodule, and the metadata dunders defined just above.
__all__ = [
    'PRFPredictor',
    'predict_prf',
    'plot_prf_prediction',
    'data',
    '__version__',
    '__author__',
    '__email__',
]
2025-03-18 11:21:54 +08:00
def predict_prf(
    sequence: Union[str, List[str], None] = None,
    data: Union[pd.DataFrame, None] = None,
    window_size: int = 3,
    short_threshold: float = 0.1,
    ensemble_weight: float = 0.4,
    model_dir: Union[str, None] = None
) -> pd.DataFrame:
    """
    Predict PRF sites, either by sliding-window scan or by region.

    Exactly one of ``sequence`` or ``data`` must be provided.

    Args:
        sequence: One DNA sequence (str), or several (list, tuple or
            pandas Series of str), for sliding-window prediction.
        data: DataFrame for region prediction; must contain a
            ``'Long_Sequence'`` column (or the legacy ``'399bp'`` column).
        window_size: Sliding window size, default 3.
        short_threshold: Probability threshold for the Short model (HistGB),
            default 0.1.
        ensemble_weight: Weight of the Short model in the ensemble, must be
            within [0.0, 1.0]; the Long model gets ``1 - ensemble_weight``.
            Default 0.4 (i.e. Long weight 0.6).
        model_dir: Optional path to the directory holding the model files.

    Returns:
        pandas.DataFrame with the prediction results. Main fields:
            - Short_Probability: Short model probability
            - Long_Probability: Long model probability
            - Ensemble_Probability: ensemble probability (primary result)
            - Ensemble_Weights: description of the weight configuration
        For a list of sequences, results are concatenated and tagged with a
        ``Sequence_ID`` column (``seq_1``, ``seq_2``, ...); sequences whose
        prediction raises are skipped with a warning, and an empty DataFrame
        is returned if none succeed. In region mode, the non-sequence
        columns of ``data`` are copied into the result. If region prediction
        raises, a warning is emitted and all probabilities are 0.0.

    Raises:
        ValueError: if neither or both of ``sequence``/``data`` are given,
            if ``ensemble_weight`` is outside [0.0, 1.0], if ``sequence`` is
            not a str/list/tuple/Series, if ``data`` is not a DataFrame, or
            if ``data`` has no ``'Long_Sequence'``/``'399bp'`` column.

    Examples:
        # 1. Sliding-window prediction on a single sequence
        >>> from FScanpy import predict_prf
        >>> sequence = "ATGCGTACGT..."
        >>> results = predict_prf(sequence=sequence)
        # 2. Sliding-window prediction on several sequences
        >>> sequences = ["ATGCGTACGT...", "GCTATAGCAT..."]
        >>> results = predict_prf(sequence=sequences)
        # 3. Custom ensemble weighting
        >>> results = predict_prf(sequence=sequence, ensemble_weight=0.3)  # 3:7
        # 4. Region prediction from a DataFrame
        >>> import pandas as pd
        >>> data = pd.DataFrame({
        ...     'Long_Sequence': ['ATGCGT...', 'GCTATAG...']  # or '399bp'
        ... })
        >>> results = predict_prf(data=data)
    """
    import warnings

    # Validate every argument BEFORE loading the models: loading is the
    # expensive step, and a load failure must not mask a simple input error.
    if sequence is None and data is None:
        raise ValueError("必须提供sequence或data参数之一")
    if sequence is not None and data is not None:
        raise ValueError("sequence和data参数不能同时提供")
    if not (0.0 <= ensemble_weight <= 1.0):
        raise ValueError("ensemble_weight 必须在 0.0 到 1.0 之间")

    seq_column = None
    if sequence is not None:
        # Any other type previously fell through and silently returned None.
        if not isinstance(sequence, (str, list, tuple, pd.Series)):
            raise ValueError("sequence参数必须是字符串或字符串列表")
    else:
        if not isinstance(data, pd.DataFrame):
            raise ValueError("data参数必须是pandas DataFrame类型")
        # Accept both the current and the legacy column name.
        if 'Long_Sequence' in data.columns:
            seq_column = 'Long_Sequence'
        elif '399bp' in data.columns:
            seq_column = '399bp'
        else:
            raise ValueError("DataFrame必须包含'Long_Sequence''399bp'")

    predictor = PRFPredictor(model_dir=model_dir)

    # Sliding-window mode, single sequence.
    if isinstance(sequence, str):
        return predictor.predict_sequence(
            sequence, window_size, short_threshold, ensemble_weight)

    # Sliding-window mode, multiple sequences (best effort per sequence).
    if sequence is not None:
        results = []
        for i, seq in enumerate(sequence, 1):
            try:
                result = predictor.predict_sequence(
                    seq, window_size, short_threshold, ensemble_weight)
                result['Sequence_ID'] = f'seq_{i}'
                results.append(result)
            except Exception as e:
                # Skip the failing sequence but keep the others.
                warnings.warn(f"警告:序列 {i} 预测失败 - {str(e)}", stacklevel=2)
        return pd.concat(results, ignore_index=True) if results else pd.DataFrame()

    # Region mode. Sequence columns are not copied back onto the result.
    sequence_columns = ('Long_Sequence', '399bp', 'Short_Sequence', '33bp')
    try:
        results = predictor.predict_regions(
            data[seq_column], short_threshold, ensemble_weight)
        # Carry the caller's remaining columns over, row by row (positional).
        for col in data.columns:
            if col not in sequence_columns:
                results[col] = data[col].values
        return results
    except Exception as e:
        warnings.warn(f"警告:区域预测失败 - {str(e)}", stacklevel=2)
        # Best-effort fallback: zero probabilities for every input row.
        # NOTE(review): 0.0 is indistinguishable from a genuine negative
        # prediction; callers should watch for the warning.
        long_weight = 1.0 - ensemble_weight
        n_rows = len(data)
        results = pd.DataFrame({
            'Short_Probability': [0.0] * n_rows,
            'Long_Probability': [0.0] * n_rows,
            'Ensemble_Probability': [0.0] * n_rows,
            'Ensemble_Weights': [f'Short:{ensemble_weight:.1f}, Long:{long_weight:.1f}'] * n_rows
        })
        # Attach all of the original columns.
        for col in data.columns:
            results[col] = data[col].values
        return results
def plot_prf_prediction(
    sequence: str,
    window_size: int = 3,
    short_threshold: float = 0.65,
    long_threshold: float = 0.8,
    ensemble_weight: float = 0.4,
    title: Union[str, None] = None,
    save_path: Union[str, None] = None,
    figsize: tuple = (12, 6),
    dpi: int = 300,
    model_dir: Union[str, None] = None
) -> tuple:
    """
    Plot the frameshift probabilities predicted along a sequence.

    Args:
        sequence: Input DNA sequence.
        window_size: Sliding window size, default 3.
        short_threshold: Filtering threshold for the Short model (HistGB),
            default 0.65.
        long_threshold: Filtering threshold for the Long model (BiLSTM-CNN),
            default 0.8.
        ensemble_weight: Weight of the Short model in the ensemble, must be
            within [0.0, 1.0]; the Long model gets ``1 - ensemble_weight``.
            Default 0.4 (i.e. Long weight 0.6).
        title: Optional figure title.
        save_path: Optional output path; if given, the figure is saved there.
        figsize: Figure size, default (12, 6).
        dpi: Figure resolution, default 300.
        model_dir: Optional path to the directory holding the model files.

    Returns:
        tuple: (pd.DataFrame, matplotlib.figure.Figure), i.e. the prediction
        results and the figure object, as returned by
        ``PRFPredictor.plot_sequence_prediction``.

    Raises:
        ValueError: if ``ensemble_weight`` is outside [0.0, 1.0]. This is
            checked before the models are loaded.

    Examples:
        # 1. Simple plot
        >>> import matplotlib.pyplot as plt
        >>> from FScanpy import plot_prf_prediction
        >>> sequence = "ATGCGTACGT..."
        >>> results, fig = plot_prf_prediction(sequence)
        >>> plt.show()
        # 2. Custom thresholds and ensemble weight
        >>> results, fig = plot_prf_prediction(
        ...     sequence,
        ...     short_threshold=0.7,
        ...     long_threshold=0.85,
        ...     ensemble_weight=0.3,  # 3:7 weighting
        ...     title="自定义权重的预测结果",
        ...     save_path="prediction_result.png"
        ... )
        # 3. Equal weighting
        >>> results, fig = plot_prf_prediction(
        ...     sequence,
        ...     ensemble_weight=0.5  # 5:5
        ... )
        # 4. Long model dominant
        >>> results, fig = plot_prf_prediction(
        ...     sequence,
        ...     ensemble_weight=0.2  # 2:8, Long model dominant
        ... )
    """
    # Validate before the (expensive) model load.
    if not (0.0 <= ensemble_weight <= 1.0):
        raise ValueError("ensemble_weight 必须在 0.0 到 1.0 之间")
    predictor = PRFPredictor(model_dir=model_dir)
    return predictor.plot_sequence_prediction(
        sequence=sequence,
        window_size=window_size,
        short_threshold=short_threshold,
        long_threshold=long_threshold,
        ensemble_weight=ensemble_weight,
        title=title,
        save_path=save_path,
        figsize=figsize,
        dpi=dpi
    )