"""Package-level API: re-exports ``PRFPredictor`` and provides convenience wrappers.

``predict_prf`` and ``plot_prf_prediction`` each build a ``PRFPredictor`` and
delegate to it, so callers do not have to manage a predictor instance.
"""

from typing import Union, List, Dict

import numpy as np
import pandas as pd

from .predictor import PRFPredictor
from . import data

__version__ = '0.3.0'
__author__ = ''
__email__ = ''

__all__ = ['PRFPredictor', 'predict_prf', 'plot_prf_prediction', 'data',
           '__version__', '__author__', '__email__']


def predict_prf(
    sequence: Union[str, List[str], None] = None,
    data: Union[pd.DataFrame, None] = None,
    window_size: int = 3,
    short_threshold: float = 0.1,
    ensemble_weight: float = 0.4,
    model_dir: Union[str, None] = None
) -> pd.DataFrame:
    """Predict PRF sites for raw sequences or for a DataFrame of regions.

    Exactly one of ``sequence`` and ``data`` must be given. ``sequence``
    selects sliding-window prediction; ``data`` selects region prediction.

    Note: inside this function the ``data`` parameter shadows the ``data``
    submodule imported at module level.

    Args:
        sequence: A single DNA sequence (``str``) or several sequences
            (``list``/``tuple`` of ``str``) for sliding-window prediction.
        data: DataFrame for region prediction. Must contain a
            ``'Long_Sequence'`` column or, under the older naming, a
            ``'399bp'`` column; ``'Long_Sequence'`` wins if both exist.
        window_size: Sliding-window size (default 3). Only used in
            sliding-window mode.
        short_threshold: Probability threshold for the Short model (HistGB)
            (default 0.1).
        ensemble_weight: Weight of the Short model in the ensemble
            (default 0.4, i.e. the Long model gets 0.6). Must be within
            [0.0, 1.0].
        model_dir: Optional directory containing the model files.

    Returns:
        pandas.DataFrame with the prediction results. The main fields are:
            - Short_Probability: probability predicted by the Short model
            - Long_Probability: probability predicted by the Long model
            - Ensemble_Probability: ensemble probability (primary result)
            - Ensemble_Weights: description of the weights used

        In multi-sequence mode a ``Sequence_ID`` column (``seq_1``,
        ``seq_2``, ...) is added; sequences whose prediction fails are
        reported via ``print`` and skipped, and an empty DataFrame is
        returned if all of them fail.

        In region mode the non-sequence columns of ``data`` are copied onto
        the result. If region prediction raises, a warning is printed and a
        zero-filled frame (one row per input row, with *all* columns of
        ``data`` copied over) is returned instead.

    Raises:
        ValueError: If neither or both of ``sequence``/``data`` are given,
            if ``ensemble_weight`` is outside [0.0, 1.0], if ``sequence`` is
            not a ``str``/``list``/``tuple``, if ``data`` is not a
            DataFrame, or if ``data`` lacks the required sequence column.

    Examples:
        # 1. Sliding-window prediction for a single sequence
        >>> from FScanpy import predict_prf
        >>> sequence = "ATGCGTACGT..."
        >>> results = predict_prf(sequence=sequence)

        # 2. Sliding-window prediction for several sequences
        >>> sequences = ["ATGCGTACGT...", "GCTATAGCAT..."]
        >>> results = predict_prf(sequence=sequences)

        # 3. Custom ensemble weights (Short:Long = 3:7)
        >>> results = predict_prf(sequence=sequence, ensemble_weight=0.3)

        # 4. Region prediction from a DataFrame
        >>> import pandas as pd
        >>> data = pd.DataFrame({
        ...     'Long_Sequence': ['ATGCGT...', 'GCTATAG...']  # or '399bp'
        ... })
        >>> results = predict_prf(data=data)
    """
    # Validate all arguments *before* constructing the predictor, so invalid
    # calls fail without building it (same order as plot_prf_prediction).
    if sequence is None and data is None:
        raise ValueError("必须提供sequence或data参数之一")
    if sequence is not None and data is not None:
        raise ValueError("sequence和data参数不能同时提供")
    if not (0.0 <= ensemble_weight <= 1.0):
        raise ValueError("ensemble_weight 必须在 0.0 到 1.0 之间")

    seq_column = None
    if sequence is not None:
        # Previously an unsupported type fell through every branch and the
        # function silently returned None; reject it explicitly instead.
        # ValueError mirrors the existing type check on ``data``.
        if not isinstance(sequence, (str, list, tuple)):
            raise ValueError("sequence参数必须是字符串或字符串列表/元组")
    else:
        if not isinstance(data, pd.DataFrame):
            raise ValueError("data参数必须是pandas DataFrame类型")
        # Accept both the new and the legacy column name.
        if 'Long_Sequence' in data.columns:
            seq_column = 'Long_Sequence'
        elif '399bp' in data.columns:
            seq_column = '399bp'
        else:
            raise ValueError("DataFrame必须包含'Long_Sequence'或'399bp'列")

    predictor = PRFPredictor(model_dir=model_dir)

    # ---- Sliding-window mode -------------------------------------------
    if sequence is not None:
        if isinstance(sequence, str):
            return predictor.predict_sequence(
                sequence, window_size, short_threshold, ensemble_weight)

        # Several sequences: best effort, one failure does not abort the rest.
        results = []
        for i, seq in enumerate(sequence, 1):
            try:
                result = predictor.predict_sequence(
                    seq, window_size, short_threshold, ensemble_weight)
                result['Sequence_ID'] = f'seq_{i}'
                results.append(result)
            except Exception as e:
                print(f"警告:序列 {i} 预测失败 - {str(e)}")
        return pd.concat(results, ignore_index=True) if results else pd.DataFrame()

    # ---- Region mode ---------------------------------------------------
    sequence_columns = ('Long_Sequence', '399bp', 'Short_Sequence', '33bp')
    try:
        results = predictor.predict_regions(
            data[seq_column], short_threshold, ensemble_weight)
        # Carry over the caller's metadata columns, but not the raw sequences.
        for col in data.columns:
            if col not in sequence_columns:
                results[col] = data[col].values
        return results
    except Exception as e:
        print(f"警告:区域预测失败 - {str(e)}")
        # Deliberate best-effort fallback: a zero-filled result with one row
        # per input row, so downstream code still receives the expected shape.
        long_weight = 1.0 - ensemble_weight
        n_rows = len(data)
        results = pd.DataFrame({
            'Short_Probability': [0.0] * n_rows,
            'Long_Probability': [0.0] * n_rows,
            'Ensemble_Probability': [0.0] * n_rows,
            'Ensemble_Weights': [f'Short:{ensemble_weight:.1f}, Long:{long_weight:.1f}'] * n_rows
        })
        # The fallback copies every input column (including sequence columns).
        for col in data.columns:
            results[col] = data[col].values
        return results


def plot_prf_prediction(
    sequence: str,
    window_size: int = 3,
    short_threshold: float = 0.65,
    long_threshold: float = 0.8,
    ensemble_weight: float = 0.4,
    title: Union[str, None] = None,
    save_path: Union[str, None] = None,
    figsize: tuple = (12, 6),
    dpi: int = 300,
    model_dir: Union[str, None] = None
) -> tuple:
    """Plot the frameshift-probability profile predicted for a sequence.

    Validates ``ensemble_weight``, builds a ``PRFPredictor`` and forwards all
    plotting arguments to ``PRFPredictor.plot_sequence_prediction``.

    Args:
        sequence: Input DNA sequence.
        window_size: Sliding-window size (default 3).
        short_threshold: Filtering threshold for the Short model (HistGB)
            (default 0.65).
        long_threshold: Filtering threshold for the Long model (BiLSTM-CNN)
            (default 0.8).
        ensemble_weight: Weight of the Short model in the ensemble
            (default 0.4, i.e. the Long model gets 0.6). Must be within
            [0.0, 1.0].
        title: Optional figure title.
        save_path: Optional output path; the figure is saved when provided.
        figsize: Figure size (default ``(12, 6)``).
        dpi: Figure resolution (default 300).
        model_dir: Optional directory containing the model files.

    Returns:
        tuple: ``(pd.DataFrame, matplotlib.figure.Figure)`` -- the prediction
        results and the figure object, as produced by
        ``PRFPredictor.plot_sequence_prediction``.

    Raises:
        ValueError: If ``ensemble_weight`` is outside [0.0, 1.0].

    Examples:
        # 1. Basic plot
        >>> import matplotlib.pyplot as plt
        >>> from FScanpy import plot_prf_prediction
        >>> sequence = "ATGCGTACGT..."
        >>> results, fig = plot_prf_prediction(sequence)
        >>> plt.show()

        # 2. Custom thresholds and ensemble weight (Short:Long = 3:7)
        >>> results, fig = plot_prf_prediction(
        ...     sequence,
        ...     short_threshold=0.7,
        ...     long_threshold=0.85,
        ...     ensemble_weight=0.3,
        ...     title="Prediction with custom weights",
        ...     save_path="prediction_result.png"
        ... )

        # 3. Equal weights (5:5)
        >>> results, fig = plot_prf_prediction(
        ...     sequence,
        ...     ensemble_weight=0.5
        ... )

        # 4. Long model dominant (2:8)
        >>> results, fig = plot_prf_prediction(
        ...     sequence,
        ...     ensemble_weight=0.2
        ... )
    """
    # Fail fast on a bad weight before the predictor is constructed.
    if not (0.0 <= ensemble_weight <= 1.0):
        raise ValueError("ensemble_weight 必须在 0.0 到 1.0 之间")

    predictor = PRFPredictor(model_dir=model_dir)
    return predictor.plot_sequence_prediction(
        sequence=sequence,
        window_size=window_size,
        short_threshold=short_threshold,
        long_threshold=long_threshold,
        ensemble_weight=ensemble_weight,
        title=title,
        save_path=save_path,
        figsize=figsize,
        dpi=dpi
    )