109 lines
3.8 KiB
Python
109 lines
3.8 KiB
Python
|
|
from .predictor import PRFPredictor
|
|||
|
|
import pandas as pd
|
|||
|
|
import numpy as np
|
|||
|
|
from typing import Union, List, Dict
|
|||
|
|
|
|||
|
|
# Package metadata.
__version__ = "0.3.0"
__author__ = ""
__email__ = ""

# Names exported by a star-import of this package.
__all__ = [
    "PRFPredictor",
    "predict_prf",
    "__version__",
    "__author__",
    "__email__",
]
|
|||
|
|
|
|||
|
|
def predict_prf(
    sequence: Union[str, List[str], None] = None,
    data: Union[pd.DataFrame, None] = None,
    window_size: int = 3,
    gb_threshold: float = 0.1,
    model_dir: Union[str, None] = None
) -> pd.DataFrame:
    """
    Predict PRF sites, either by sliding window over sequences or per region.

    Exactly one of ``sequence`` and ``data`` must be provided.

    Args:
        sequence: A single DNA sequence, or a list/tuple of DNA sequences,
            used for sliding-window prediction.
        data: DataFrame that must contain a '399bp' column, used for
            region prediction.
        window_size: Sliding-window size (default 3).
        gb_threshold: Probability threshold of the GB model (default 0.1).
        model_dir: Directory containing the model files (optional).

    Returns:
        pandas.DataFrame: The prediction results.

        * Multiple sequences: the per-sequence results concatenated, each
          tagged with a 'Sequence_ID' column ('seq_1', 'seq_2', ...).
          Sequences whose prediction fails are reported on stdout and
          skipped; an empty DataFrame is returned if none succeed.
        * Region mode: the predictor output plus every column of ``data``
          other than '399bp' and '33bp'. If region prediction fails, a
          warning is printed and a frame with zeroed 'GB_Probability',
          'CNN_Probability' and 'Voting_Probability' columns plus all
          columns of ``data`` is returned instead.

    Raises:
        ValueError: If neither or both of ``sequence`` and ``data`` are
            given, if ``sequence`` is not a str/list/tuple, if ``data`` is
            not a DataFrame, or if ``data`` lacks the '399bp' column.

    Examples:
        # 1. Sliding-window prediction for a single sequence
        >>> from FScanpy import predict_prf
        >>> sequence = "ATGCGTACGT..."
        >>> results = predict_prf(sequence=sequence)

        # 2. Sliding-window prediction for several sequences
        >>> sequences = ["ATGCGTACGT...", "GCTATAGCAT..."]
        >>> results = predict_prf(sequence=sequences)

        # 3. Region prediction from a DataFrame
        >>> import pandas as pd
        >>> data = pd.DataFrame({
        ...     '399bp': ['ATGCGT...', 'GCTATAG...']
        ... })
        >>> results = predict_prf(data=data)
    """
    # Validate every argument before building the predictor, so that invalid
    # calls fail fast instead of paying for predictor construction first.
    if sequence is None and data is None:
        raise ValueError("必须提供sequence或data参数之一")
    if sequence is not None and data is not None:
        raise ValueError("sequence和data参数不能同时提供")

    if sequence is not None:
        # Previously an unsupported type fell through every branch and the
        # function silently returned None; reject it explicitly instead.
        if not isinstance(sequence, (str, list, tuple)):
            raise ValueError("sequence参数必须是字符串或字符串列表/元组")
    else:
        if not isinstance(data, pd.DataFrame):
            raise ValueError("data参数必须是pandas DataFrame类型")
        if '399bp' not in data.columns:
            raise ValueError("DataFrame必须包含'399bp'列")

    predictor = PRFPredictor(model_dir=model_dir)

    # Sliding-window mode
    if sequence is not None:
        if isinstance(sequence, str):
            # Single sequence
            return predictor.predict_full(sequence, window_size, gb_threshold)

        # Multiple sequences: best effort, a failing sequence is reported
        # and skipped so the remaining ones are still predicted.
        results = []
        for i, seq in enumerate(sequence, 1):
            try:
                result = predictor.predict_full(seq, window_size, gb_threshold)
                result['Sequence_ID'] = f'seq_{i}'
                results.append(result)
            except Exception as e:
                print(f"警告:序列 {i} 预测失败 - {e}")
        return pd.concat(results, ignore_index=True) if results else pd.DataFrame()

    # Region mode
    try:
        results = predictor.predict_region(data['399bp'], gb_threshold)

        # Carry over the remaining input columns positionally (.values
        # ignores the index of ``data``).
        for col in data.columns:
            if col not in ['399bp', '33bp']:
                results[col] = data[col].values

        return results

    except Exception as e:
        print(f"警告:区域预测失败 - {e}")

        # Fall back to an all-zero result with one row per input row.
        results = pd.DataFrame({
            'GB_Probability': [0.0] * len(data),
            'CNN_Probability': [0.0] * len(data),
            'Voting_Probability': [0.0] * len(data)
        })

        # Attach all original input columns.
        for col in data.columns:
            results[col] = data[col].values

        return results
|