FScanpy-package/FScanpy/__init__.py

205 lines
7.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from .predictor import PRFPredictor
from . import data
import pandas as pd
import numpy as np
from typing import Union, List, Dict
# Package metadata.
__version__ = "0.3.0"
__author__ = ""
__email__ = ""

# Public API: the names exported by ``from FScanpy import *``.
__all__ = [
    "PRFPredictor",
    "predict_prf",
    "plot_prf_prediction",
    "data",
    "__version__",
    "__author__",
    "__email__",
]
def predict_prf(
    sequence: Union[str, List[str], None] = None,
    data: Union[pd.DataFrame, None] = None,
    window_size: int = 3,
    short_threshold: float = 0.1,
    ensemble_weight: float = 0.4,
    model_dir: str = None
) -> pd.DataFrame:
    """
    Predict PRF sites, either by sliding window over sequences or per region.

    Exactly one of ``sequence`` or ``data`` must be given:

    * ``sequence`` -- one DNA sequence (``str``) or a collection of sequences
      (``list``, ``tuple``, ``pandas.Series`` or ``numpy.ndarray``), each
      scanned with a sliding window.
    * ``data`` -- a DataFrame with a ``'Long_Sequence'`` column (or the legacy
      ``'399bp'`` column) for region-level prediction.

    Args:
        sequence: One or more DNA sequences for sliding-window prediction.
        data: DataFrame for region prediction; must contain a
            'Long_Sequence' or '399bp' column.
        window_size: Sliding-window size, default 3.
        short_threshold: Probability threshold for the Short model (HistGB),
            default 0.1.
        ensemble_weight: Weight of the Short model in the ensemble, in
            [0.0, 1.0]; the Long model receives ``1 - ensemble_weight``
            (default 0.4 / 0.6).
        model_dir: Optional directory containing the model files.

    Returns:
        pandas.DataFrame with, among others, these columns:
            - Short_Probability: Short model probability
            - Long_Probability: Long model probability
            - Ensemble_Probability: ensemble probability (main result)
            - Ensemble_Weights: description of the weight configuration
        With a collection of sequences, a ``Sequence_ID`` column
        (``seq_1``, ``seq_2``, ...) identifies each input. A sequence whose
        prediction raises is skipped with a printed warning. If none
        succeed, an empty DataFrame is returned.
        With ``data``, the non-sequence columns of ``data`` are copied onto
        the result. If region prediction raises, a warning is printed and
        every probability is reported as 0.0.

    Raises:
        ValueError: if neither or both of ``sequence``/``data`` are given,
            ``ensemble_weight`` is outside [0.0, 1.0], ``sequence`` has an
            unsupported type, ``data`` is not a DataFrame, or ``data`` has
            no sequence column. All of these checks run before
            ``PRFPredictor`` is constructed.

    Examples:
        # 1. Sliding-window prediction on a single sequence
        >>> from FScanpy import predict_prf
        >>> sequence = "ATGCGTACGT..."
        >>> results = predict_prf(sequence=sequence)
        # 2. Sliding-window prediction on several sequences
        >>> sequences = ["ATGCGTACGT...", "GCTATAGCAT..."]
        >>> results = predict_prf(sequence=sequences)
        # 3. Custom ensemble weighting
        >>> results = predict_prf(sequence=sequence, ensemble_weight=0.3)  # 3:7
        # 4. Region prediction from a DataFrame
        >>> import pandas as pd
        >>> data = pd.DataFrame({
        ...     'Long_Sequence': ['ATGCGT...', 'GCTATAG...']  # or use '399bp'
        ... })
        >>> results = predict_prf(data=data)
    """
    # Validate all arguments before building the predictor, so bad input
    # fails fast with ValueError instead of first constructing the models.
    if sequence is None and data is None:
        raise ValueError("必须提供sequence或data参数之一")
    if sequence is not None and data is not None:
        raise ValueError("sequence和data参数不能同时提供")
    if not (0.0 <= ensemble_weight <= 1.0):
        raise ValueError("ensemble_weight 必须在 0.0 到 1.0 之间")

    if sequence is not None:
        # Reject unsupported types explicitly rather than falling through
        # every branch and implicitly returning None.
        if not isinstance(sequence, (str, list, tuple, pd.Series, np.ndarray)):
            raise ValueError("sequence参数必须是字符串或字符串列表")
        predictor = PRFPredictor(model_dir=model_dir)
        return _predict_sequences(
            predictor, sequence, window_size, short_threshold, ensemble_weight)

    seq_column = _find_sequence_column(data)
    predictor = PRFPredictor(model_dir=model_dir)
    return _predict_dataframe(
        predictor, data, seq_column, short_threshold, ensemble_weight)


def _find_sequence_column(data):
    """Return the sequence column name of *data*, validating its type first."""
    if not isinstance(data, pd.DataFrame):
        raise ValueError("data参数必须是pandas DataFrame类型")
    # Prefer the current column name; fall back to the legacy '399bp' name.
    for name in ('Long_Sequence', '399bp'):
        if name in data.columns:
            return name
    raise ValueError("DataFrame必须包含'Long_Sequence''399bp'")


def _predict_sequences(predictor, sequence, window_size, short_threshold,
                       ensemble_weight):
    """Run sliding-window prediction on one sequence or a collection of them."""
    if isinstance(sequence, str):
        return predictor.predict_sequence(
            sequence, window_size, short_threshold, ensemble_weight)

    results = []
    for i, seq in enumerate(sequence, 1):
        try:
            result = predictor.predict_sequence(
                seq, window_size, short_threshold, ensemble_weight)
            result['Sequence_ID'] = f'seq_{i}'
            results.append(result)
        except Exception as e:
            # Best-effort batch: report the failing sequence and keep going.
            print(f"警告:序列 {i} 预测失败 - {str(e)}")
    return pd.concat(results, ignore_index=True) if results else pd.DataFrame()


def _predict_dataframe(predictor, data, seq_column, short_threshold,
                       ensemble_weight):
    """Run region prediction on data[seq_column] and carry over other columns."""
    try:
        results = predictor.predict_regions(
            data[seq_column], short_threshold, ensemble_weight)
        # Copy the caller's remaining columns. Both current and legacy
        # sequence column names are excluded.
        for col in data.columns:
            if col not in ('Long_Sequence', '399bp', 'Short_Sequence', '33bp'):
                results[col] = data[col].values
        return results
    except Exception as e:
        # Best-effort: report the failure and return placeholder rows.
        print(f"警告:区域预测失败 - {str(e)}")
        return _fallback_region_results(data, ensemble_weight)


def _fallback_region_results(data, ensemble_weight):
    """Build the all-zero result returned when region prediction fails."""
    n_rows = len(data)
    long_weight = 1.0 - ensemble_weight
    results = pd.DataFrame({
        'Short_Probability': [0.0] * n_rows,
        'Long_Probability': [0.0] * n_rows,
        'Ensemble_Probability': [0.0] * n_rows,
        'Ensemble_Weights': [f'Short:{ensemble_weight:.1f}, Long:{long_weight:.1f}'] * n_rows,
    })
    # Unlike the success path, every input column (sequences included) is kept.
    for col in data.columns:
        results[col] = data[col].values
    return results
def plot_prf_prediction(
    sequence: str,
    window_size: int = 3,
    short_threshold: float = 0.65,
    long_threshold: float = 0.8,
    ensemble_weight: float = 0.4,
    title: str = None,
    save_path: str = None,
    figsize: tuple = (12, 6),
    dpi: int = 300,
    model_dir: str = None
) -> tuple:
    """
    Plot the predicted frameshift probability profile for one sequence.

    Validates ``ensemble_weight`` and then delegates to
    ``PRFPredictor.plot_sequence_prediction``. Every other argument except
    ``model_dir`` is forwarded unchanged.

    Args:
        sequence: Input DNA sequence.
        window_size: Sliding-window size, default 3.
        short_threshold: Filtering threshold for the Short model (HistGB),
            default 0.65.
        long_threshold: Filtering threshold for the Long model (BiLSTM-CNN),
            default 0.8.
        ensemble_weight: Weight of the Short model in the ensemble, in
            [0.0, 1.0]; the Long model gets the remainder (default 0.4 / 0.6).
        title: Optional figure title.
        save_path: Optional path; when given, the figure is saved there.
        figsize: Figure size, default (12, 6).
        dpi: Figure resolution, default 300.
        model_dir: Optional directory containing the model files.

    Returns:
        tuple: ``(pd.DataFrame, matplotlib.figure.Figure)`` -- the prediction
        results and the figure object.

    Raises:
        ValueError: if ``ensemble_weight`` is outside [0.0, 1.0]. This is
            checked before ``PRFPredictor`` is constructed.

    Examples:
        # 1. Basic plot
        >>> import matplotlib.pyplot as plt
        >>> from FScanpy import plot_prf_prediction
        >>> sequence = "ATGCGTACGT..."
        >>> results, fig = plot_prf_prediction(sequence)
        >>> plt.show()
        # 2. Custom thresholds and ensemble weight
        >>> results, fig = plot_prf_prediction(
        ...     sequence,
        ...     short_threshold=0.7,
        ...     long_threshold=0.85,
        ...     ensemble_weight=0.3,  # 3:7 weighting
        ...     title="Custom-weight prediction",
        ...     save_path="prediction_result.png"
        ... )
        # 3. Equal weighting
        >>> results, fig = plot_prf_prediction(sequence, ensemble_weight=0.5)
        # 4. Long-model dominated (2:8)
        >>> results, fig = plot_prf_prediction(sequence, ensemble_weight=0.2)
    """
    weight_in_range = 0.0 <= ensemble_weight <= 1.0
    if not weight_in_range:
        raise ValueError("ensemble_weight 必须在 0.0 到 1.0 之间")

    # Everything except model_dir is passed straight through to the predictor.
    plot_options = {
        'sequence': sequence,
        'window_size': window_size,
        'short_threshold': short_threshold,
        'long_threshold': long_threshold,
        'ensemble_weight': ensemble_weight,
        'title': title,
        'save_path': save_path,
        'figsize': figsize,
        'dpi': dpi,
    }
    return PRFPredictor(model_dir=model_dir).plot_sequence_prediction(**plot_options)