109 lines
3.8 KiB
Python
109 lines
3.8 KiB
Python
import warnings
from typing import Dict, List, Union

import numpy as np
import pandas as pd

from .predictor import PRFPredictor
|
||
|
||
# Package metadata.
__version__ = "0.3.0"
__author__ = ""  # TODO(review): placeholder, fill in
__email__ = ""  # TODO(review): placeholder, fill in

# Names exported by ``from FScanpy import *``.
__all__ = [
    "PRFPredictor",
    "predict_prf",
    "__version__",
    "__author__",
    "__email__",
]
|
||
|
||
def predict_prf(
    sequence: Union[str, List[str], None] = None,
    data: Union[pd.DataFrame, None] = None,
    window_size: int = 3,
    gb_threshold: float = 0.1,
    model_dir: Union[str, None] = None,
) -> pd.DataFrame:
    """Predict PRF sites with a ``PRFPredictor``.

    Exactly one of ``sequence`` or ``data`` must be given; it selects the mode:

    * ``sequence`` -> sliding-window prediction via ``predictor.predict_full``.
    * ``data`` -> region prediction on the ``'399bp'`` column via
      ``predictor.predict_region``.

    Args:
        sequence: A single DNA sequence, or a list/tuple (a pandas Series or
            1-D NumPy array is also accepted) of sequences, for sliding-window
            prediction.
        data: DataFrame that must contain a ``'399bp'`` column, for region
            prediction.
        window_size: Sliding-window size (default 3). Used in sequence mode only.
        gb_threshold: Probability threshold for the GB model (default 0.1).
        model_dir: Optional path to the model-file directory, passed through to
            ``PRFPredictor``.

    Returns:
        pandas.DataFrame with the prediction results.

        * Single sequence: whatever ``predict_full`` returns.
        * Several sequences: the per-sequence results concatenated, each tagged
          with a 1-based ``Sequence_ID`` (``'seq_1'``, ``'seq_2'``, ...).
          Sequences whose prediction raises are skipped with a
          ``RuntimeWarning``; if nothing succeeds (or the input is empty) an
          empty DataFrame is returned.
        * Region mode: the ``predict_region`` result plus every column of
          ``data`` except ``'399bp'`` and ``'33bp'``. If that step raises, a
          ``RuntimeWarning`` is issued and a fallback frame is returned with
          ``GB_Probability``, ``CNN_Probability`` and ``Voting_Probability``
          all 0.0, plus every column of ``data``.

    Raises:
        ValueError: if neither or both of ``sequence``/``data`` are given, if
            ``sequence`` has an unsupported type, if ``data`` is not a
            DataFrame, or if ``data`` lacks the ``'399bp'`` column.

    Examples:
        # 1. Sliding-window prediction on a single sequence
        >>> from FScanpy import predict_prf
        >>> sequence = "ATGCGTACGT..."
        >>> results = predict_prf(sequence=sequence)

        # 2. Sliding-window prediction on several sequences
        >>> sequences = ["ATGCGTACGT...", "GCTATAGCAT..."]
        >>> results = predict_prf(sequence=sequences)

        # 3. Region prediction from a DataFrame
        >>> import pandas as pd
        >>> data = pd.DataFrame({
        ...     '399bp': ['ATGCGT...', 'GCTATAG...']
        ... })
        >>> results = predict_prf(data=data)
    """
    # --- Validate every argument before touching the predictor. -------------
    if sequence is None and data is None:
        raise ValueError("必须提供sequence或data参数之一")
    if sequence is not None and data is not None:
        raise ValueError("sequence和data参数不能同时提供")

    if sequence is not None:
        # Previously any other type fell through and returned None silently.
        if not isinstance(sequence, (str, list, tuple, pd.Series, np.ndarray)):
            raise ValueError(
                f"sequence参数必须是字符串或字符串列表,实际类型为 {type(sequence).__name__}"
            )
    else:
        if not isinstance(data, pd.DataFrame):
            raise ValueError("data参数必须是pandas DataFrame类型")
        if '399bp' not in data.columns:
            raise ValueError("DataFrame必须包含'399bp'列")

    # Built only after validation. Construction presumably loads the model files
    # from model_dir, which is wasted work, and would mask the real error, when
    # the arguments are invalid.
    predictor = PRFPredictor(model_dir=model_dir)

    # --- Sliding-window mode. ------------------------------------------------
    if sequence is not None:
        if isinstance(sequence, str):
            return predictor.predict_full(sequence, window_size, gb_threshold)
        return _predict_sequences(predictor, list(sequence), window_size, gb_threshold)

    # --- Region mode. --------------------------------------------------------
    return _predict_region(predictor, data, gb_threshold)


def _predict_sequences(predictor, sequences, window_size, gb_threshold):
    """Run ``predict_full`` on each sequence, tag it, and stack the results.

    Best-effort: a sequence whose prediction raises is reported with a
    ``RuntimeWarning`` and skipped, so one bad input does not abort the batch.
    Returns an empty DataFrame when no sequence succeeds.
    """
    results = []
    for i, seq in enumerate(sequences, 1):
        try:
            result = predictor.predict_full(seq, window_size, gb_threshold)
            result['Sequence_ID'] = f'seq_{i}'
        except Exception as e:
            # stacklevel=3 attributes the warning to the caller of predict_prf.
            warnings.warn(f"警告:序列 {i} 预测失败 - {e}", RuntimeWarning, stacklevel=3)
        else:
            results.append(result)
    return pd.concat(results, ignore_index=True) if results else pd.DataFrame()


def _predict_region(predictor, data, gb_threshold):
    """Run ``predict_region`` on ``data['399bp']`` and attach ``data``'s columns.

    Best-effort: if prediction or column attachment raises, a ``RuntimeWarning``
    is issued and a zero-probability frame carrying all of ``data``'s columns
    is returned instead.
    """
    try:
        results = predictor.predict_region(data['399bp'], gb_threshold)
        # Copy the caller's extra columns positionally onto the result.
        # NOTE(review): '399bp'/'33bp' are skipped here, but they are kept in
        # the fallback below, presumably because predict_region already emits
        # them. Confirm.
        for col in data.columns:
            if col not in ('399bp', '33bp'):
                results[col] = data[col].values
        return results
    except Exception as e:
        warnings.warn(f"警告:区域预测失败 - {e}", RuntimeWarning, stacklevel=3)

    # Fallback result. NOTE: 0.0 is indistinguishable from a genuine low score;
    # callers must heed the warning above.
    n_rows = len(data)
    results = pd.DataFrame({
        'GB_Probability': [0.0] * n_rows,
        'CNN_Probability': [0.0] * n_rows,
        'Voting_Probability': [0.0] * n_rows,
    })
    for col in data.columns:
        results[col] = data[col].values
    return results