From 089df9c4a616a18cd5cb6ca589c755e64cbee07b Mon Sep 17 00:00:00 2001 From: Chenlab Date: Wed, 11 Jun 2025 21:18:52 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96=E7=BB=98=E5=9B=BE=EF=BC=8C?= =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E4=BD=BF=E7=94=A8=E4=BB=8B=E7=BB=8D=E7=9A=84?= =?UTF-8?q?ipynb?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- FScanpy/__init__.py | 114 ++--- FScanpy/predictor.py | 984 ++++++++++++++++++++++--------------------- FScanpy_Demo.ipynb | 281 ++++++------ 3 files changed, 692 insertions(+), 687 deletions(-) diff --git a/FScanpy/__init__.py b/FScanpy/__init__.py index a30b0ad..4962d6a 100644 --- a/FScanpy/__init__.py +++ b/FScanpy/__init__.py @@ -19,61 +19,61 @@ def predict_prf( model_dir: str = None ) -> pd.DataFrame: """ - PRF位点预测函数 + PRF site prediction function Args: - sequence: 单个或多个DNA序列,用于滑动窗口预测 - data: DataFrame数据,必须包含'Long_Sequence'或'399bp'列,用于区域预测 - window_size: 滑动窗口大小(默认为3) - short_threshold: Short模型(HistGB)概率阈值(默认为0.1) - ensemble_weight: Short模型在集成中的权重(默认为0.4,Long权重为0.6) - model_dir: 模型文件目录路径(可选) + sequence: Single or multiple DNA sequences for sliding window prediction + data: DataFrame data, must contain 'Long_Sequence' or '399bp' column for region prediction + window_size: Sliding window size (default: 3) + short_threshold: Short model (HistGB) probability threshold (default: 0.1) + ensemble_weight: Weight of short model in ensemble (default: 0.4, long weight: 0.6) + model_dir: Model directory path (optional) Returns: - pandas.DataFrame: 预测结果,包含以下主要字段: - - Short_Probability: Short模型预测概率 - - Long_Probability: Long模型预测概率 - - Ensemble_Probability: 集成预测概率(主要结果) - - Ensemble_Weights: 权重配置信息 + pandas.DataFrame: Prediction results containing the following main fields: + - Short_Probability: Short model prediction probability + - Long_Probability: Long model prediction probability + - Ensemble_Probability: Ensemble prediction probability (main result) + - Ensemble_Weights: Weight configuration information Examples: - # 1. 单条序列滑动窗口预测 + # 1. Single sequence sliding window prediction >>> from FScanpy import predict_prf >>> sequence = "ATGCGTACGT..." >>> results = predict_prf(sequence=sequence) - # 2. 多条序列滑动窗口预测 + # 2. Multiple sequences sliding window prediction >>> sequences = ["ATGCGTACGT...", "GCTATAGCAT..."] >>> results = predict_prf(sequence=sequences) - # 3. 自定义集成权重比例 - >>> results = predict_prf(sequence=sequence, ensemble_weight=0.3) # 3:7 比例 + # 3. Custom ensemble weight ratio + >>> results = predict_prf(sequence=sequence, ensemble_weight=0.3) # 3:7 ratio - # 4. DataFrame区域预测 + # 4. DataFrame region prediction >>> import pandas as pd >>> data = pd.DataFrame({ - ... 'Long_Sequence': ['ATGCGT...', 'GCTATAG...'] # 或使用 '399bp' + ... 'Long_Sequence': ['ATGCGT...', 'GCTATAG...'] # or use '399bp' ... }) >>> results = predict_prf(data=data) """ predictor = PRFPredictor(model_dir=model_dir) - # 验证输入参数 + # Validate input parameters if sequence is None and data is None: - raise ValueError("必须提供sequence或data参数之一") + raise ValueError("Must provide either sequence or data parameter") if sequence is not None and data is not None: - raise ValueError("sequence和data参数不能同时提供") + raise ValueError("Cannot provide both sequence and data parameters") if not (0.0 <= ensemble_weight <= 1.0): - raise ValueError("ensemble_weight 必须在 0.0 到 1.0 之间") + raise ValueError("ensemble_weight must be between 0.0 and 1.0") - # 滑动窗口预测模式 + # Sliding window prediction mode if sequence is not None: if isinstance(sequence, str): - # 单条序列预测 + # Single sequence prediction return predictor.predict_sequence( sequence, window_size, short_threshold, ensemble_weight) elif isinstance(sequence, (list, tuple)): - # 多条序列预测 + # Multiple sequences prediction results = [] for i, seq in enumerate(sequence, 1): try: @@ -82,29 +82,29 @@ def predict_prf( result['Sequence_ID'] = f'seq_{i}' results.append(result) except Exception as e: - print(f"警告:序列 {i} 预测失败 - {str(e)}") + print(f"Warning: Sequence {i} prediction failed - {str(e)}") return pd.concat(results, ignore_index=True) if results else pd.DataFrame() - # 区域化预测模式 + # Region prediction mode else: if not isinstance(data, pd.DataFrame): - raise ValueError("data参数必须是pandas DataFrame类型") + raise ValueError("data parameter must be pandas DataFrame type") - # 检查列名(支持新旧两种命名) + # Check column names (support both new and old naming conventions) seq_column = None if 'Long_Sequence' in data.columns: seq_column = 'Long_Sequence' elif '399bp' in data.columns: seq_column = '399bp' else: - raise ValueError("DataFrame必须包含'Long_Sequence'或'399bp'列") + raise ValueError("DataFrame must contain 'Long_Sequence' or '399bp' column") - # 调用区域预测函数 + # Call region prediction function try: results = predictor.predict_regions( data[seq_column], short_threshold, ensemble_weight) - # 添加原始数据的其他列 + # Add other columns from original data for col in data.columns: if col not in ['Long_Sequence', '399bp', 'Short_Sequence', '33bp']: results[col] = data[col].values @@ -112,8 +112,8 @@ def predict_prf( return results except Exception as e: - print(f"警告:区域预测失败 - {str(e)}") - # 创建空结果 + print(f"Warning: Region prediction failed - {str(e)}") + # Create empty results long_weight = 1.0 - ensemble_weight results = pd.DataFrame({ 'Short_Probability': [0.0] * len(data), @@ -122,7 +122,7 @@ def predict_prf( 'Ensemble_Weights': [f'Short:{ensemble_weight:.1f}, Long:{long_weight:.1f}'] * len(data) }) - # 添加原始数据列 + # Add original data columns for col in data.columns: results[col] = data[col].values @@ -136,59 +136,59 @@ def plot_prf_prediction( ensemble_weight: float = 0.4, title: str = None, save_path: str = None, - figsize: tuple = (12, 6), + figsize: tuple = (12, 8), dpi: int = 300, model_dir: str = None ) -> tuple: """ - 绘制序列PRF预测结果的移码概率图 + Plot PRF prediction results for sequence frameshifting probability Args: - sequence: 输入DNA序列 - window_size: 滑动窗口大小(默认为3) - short_threshold: Short模型(HistGB)过滤阈值(默认为0.65) - long_threshold: Long模型(BiLSTM-CNN)过滤阈值(默认为0.8) - ensemble_weight: Short模型在集成中的权重(默认为0.4,Long权重为0.6) - title: 图片标题(可选) - save_path: 保存路径(可选,如果提供则保存图片) - figsize: 图片尺寸(默认为(12, 6)) - dpi: 图片分辨率(默认为300) - model_dir: 模型文件目录路径(可选) + sequence: Input DNA sequence + window_size: Sliding window size (default: 3) + short_threshold: Short model (HistGB) filtering threshold (default: 0.65) + long_threshold: Long model (BiLSTM-CNN) filtering threshold (default: 0.8) + ensemble_weight: Weight of short model in ensemble (default: 0.4, long weight: 0.6) + title: Plot title (optional) + save_path: Save path (optional, saves plot if provided) + figsize: Figure size (default: (12, 8)) + dpi: Figure resolution (default: 300) + model_dir: Model directory path (optional) Returns: - tuple: (pd.DataFrame, matplotlib.figure.Figure) 预测结果和图形对象 + tuple: (pd.DataFrame, matplotlib.figure.Figure) prediction results and figure object Examples: - # 1. 简单绘图 + # 1. Simple plotting >>> from FScanpy import plot_prf_prediction >>> sequence = "ATGCGTACGT..." >>> results, fig = plot_prf_prediction(sequence) >>> plt.show() - # 2. 自定义阈值和集成权重 + # 2. Custom thresholds and ensemble weights >>> results, fig = plot_prf_prediction( ... sequence, ... short_threshold=0.7, ... long_threshold=0.85, - ... ensemble_weight=0.3, # 3:7 权重比例 - ... title="自定义权重的预测结果", + ... ensemble_weight=0.3, # 3:7 weight ratio + ... title="Custom Weight Prediction Results", ... save_path="prediction_result.png" ... ) - # 3. 等权重组合 + # 3. Equal weight combination >>> results, fig = plot_prf_prediction( ... sequence, - ... ensemble_weight=0.5 # 5:5 等权重 + ... ensemble_weight=0.5 # 5:5 equal weights ... ) - # 4. Long模型主导 + # 4. Long model dominated >>> results, fig = plot_prf_prediction( ... sequence, - ... ensemble_weight=0.2 # 2:8 权重,Long模型主导 + ... ensemble_weight=0.2 # 2:8 weights, long model dominated ... ) """ if not (0.0 <= ensemble_weight <= 1.0): - raise ValueError("ensemble_weight 必须在 0.0 到 1.0 之间") + raise ValueError("ensemble_weight must be between 0.0 and 1.0") predictor = PRFPredictor(model_dir=model_dir) diff --git a/FScanpy/predictor.py b/FScanpy/predictor.py index 27a352b..09ea283 100644 --- a/FScanpy/predictor.py +++ b/FScanpy/predictor.py @@ -1,487 +1,499 @@ -import os -import pickle -import numpy as np -import pandas as pd -from tensorflow.keras.models import load_model -from .features.sequence import SequenceFeatureExtractor -from .features.cnn_input import CNNInputProcessor -from .utils import extract_window_sequences -import matplotlib.pyplot as plt -import joblib - - -class PRFPredictor: - - def __init__(self, model_dir=None): - """ - 初始化PRF预测器 - - Args: - model_dir: 模型目录路径(可选) - """ - if model_dir is None: - from pkg_resources import resource_filename - model_dir = resource_filename('FScanpy', 'pretrained') - - try: - # 加载模型 - 使用新的命名约定 - self.short_model = self._load_pickle(os.path.join(model_dir, 'short.pkl')) # HistGB模型 - self.long_model = self._load_pickle(os.path.join(model_dir, 'long.pkl')) # BiLSTM-CNN模型 - - # 初始化特征提取器和CNN处理器,使用与训练时相同的序列长度 - self.short_seq_length = 33 # HistGB使用的序列长度 - self.long_seq_length = 399 # BiLSTM-CNN使用的序列长度 - - # 初始化特征提取器和CNN输入处理器 - self.feature_extractor = SequenceFeatureExtractor(seq_length=self.short_seq_length) - self.cnn_processor = CNNInputProcessor(max_length=self.long_seq_length) - - # 检测模型类型以优化预测性能 - self._detect_model_types() - - except FileNotFoundError as e: - raise FileNotFoundError(f"无法找到模型文件: {str(e)}。请确保模型文件 'short.pkl' 和 'long.pkl' 存在于 {model_dir}") - except Exception as e: - raise Exception(f"加载模型出错: {str(e)}") - - def _load_pickle(self, path): - """安全加载pickle文件""" - try: - return joblib.load(path) - except Exception as e: - raise FileNotFoundError(f"无法加载模型文件 {path}: {str(e)}") - - def _detect_model_types(self): - """检测模型类型以优化预测性能""" - self.short_is_sklearn = hasattr(self.short_model, 'predict_proba') - self.long_is_sklearn = hasattr(self.long_model, 'predict_proba') - - def _predict_model(self, model, features, is_sklearn, seq_length): - """统一的模型预测方法""" - try: - if is_sklearn: - # sklearn模型使用特征向量 - if isinstance(features, np.ndarray) and features.ndim > 1: - features = features.flatten() - features_2d = np.array([features]) - pred = model.predict_proba(features_2d) - return pred[0][1] - else: - # 深度学习模型 - if seq_length == self.long_seq_length: - # 对于长序列,使用CNN处理器 - model_input = self.cnn_processor.prepare_sequence(features) - else: - # 对于短序列,转换为数值编码 - base_to_num = {'A': 1, 'T': 2, 'G': 3, 'C': 4, 'N': 0} - seq_numeric = [base_to_num.get(base, 0) for base in features.upper()] - model_input = np.array(seq_numeric).reshape(1, len(seq_numeric), 1) - - # 统一的预测调用 - try: - pred = model.predict(model_input, verbose=0) - except TypeError: - pred = model.predict(model_input) - - # 处理预测结果 - if isinstance(pred, list): - pred = pred[0] - - if hasattr(pred, 'shape') and len(pred.shape) > 1 and pred.shape[1] > 1: - return pred[0][1] - else: - return pred[0][0] if hasattr(pred[0], '__getitem__') else pred[0] - - except Exception as e: - raise Exception(f"模型预测失败: {str(e)}") - - def predict_single_position(self, fs_period, full_seq, short_threshold=0.1, ensemble_weight=0.4): - ''' - 预测单个位置的PRF状态 - - Args: - fs_period: 33bp序列 (short模型使用) - full_seq: 完整序列 (long模型使用) - short_threshold: short模型的概率阈值 (默认为0.1) - ensemble_weight: short模型在集成中的权重 (默认为0.4,long权重为0.6) - Returns: - dict: 包含预测概率的字典 - ''' - try: - # 验证权重参数 - if not (0.0 <= ensemble_weight <= 1.0): - raise ValueError("ensemble_weight 必须在 0.0 到 1.0 之间") - - long_weight = 1.0 - ensemble_weight - - # 处理序列长度 - if len(fs_period) > self.short_seq_length: - fs_period = self.feature_extractor.trim_sequence(fs_period, self.short_seq_length) - - # Short模型预测 (HistGB) - try: - if self.short_is_sklearn: - short_features = self.feature_extractor.extract_features(fs_period) - short_prob = self._predict_model(self.short_model, short_features, True, self.short_seq_length) - else: - short_prob = self._predict_model(self.short_model, fs_period, False, self.short_seq_length) - except Exception as e: - print(f"Short模型预测时出错: {str(e)}") - short_prob = 0.0 - - # 如果short概率低于阈值,则跳过long模型 - if short_prob < short_threshold: - return { - 'Short_Probability': short_prob, - 'Long_Probability': 0.0, - 'Ensemble_Probability': 0.0, - 'Ensemble_Weights': f'Short:{ensemble_weight:.1f}, Long:{long_weight:.1f}' - } - - # Long模型预测 (BiLSTM-CNN) - try: - if self.long_is_sklearn: - long_features = self.feature_extractor.extract_features(full_seq) - long_prob = self._predict_model(self.long_model, long_features, True, self.long_seq_length) - else: - long_prob = self._predict_model(self.long_model, full_seq, False, self.long_seq_length) - except Exception as e: - print(f"Long模型预测时出错: {str(e)}") - long_prob = 0.0 - - # 计算集成概率 - try: - ensemble_prob = ensemble_weight * short_prob + long_weight * long_prob - except Exception as e: - print(f"计算集成概率时出错: {str(e)}") - ensemble_prob = (short_prob + long_prob) / 2 - - return { - 'Short_Probability': short_prob, - 'Long_Probability': long_prob, - 'Ensemble_Probability': ensemble_prob, - 'Ensemble_Weights': f'Short:{ensemble_weight:.1f}, Long:{long_weight:.1f}' - } - - except Exception as e: - raise Exception(f"预测过程出错: {str(e)}") - - def predict_sequence(self, sequence, window_size=3, short_threshold=0.1, ensemble_weight=0.4): - """ - 预测完整序列中的PRF位点(滑动窗口方法) - - Args: - sequence: 输入DNA序列 - window_size: 滑动窗口大小 (默认为3) - short_threshold: short模型概率阈值 (默认为0.1) - ensemble_weight: short模型在集成中的权重 (默认为0.4) - - Returns: - pd.DataFrame: 包含预测结果的DataFrame - """ - if window_size < 1: - raise ValueError("窗口大小必须大于等于1") - if short_threshold < 0: - raise ValueError("short模型阈值必须大于等于0") - if not (0.0 <= ensemble_weight <= 1.0): - raise ValueError("ensemble_weight 必须在 0.0 到 1.0 之间") - - results = [] - long_weight = 1.0 - ensemble_weight - - try: - # 确保序列为字符串并转换为大写 - sequence = str(sequence).upper() - - # 滑动窗口预测 - for pos in range(0, len(sequence) - 2, window_size): - # 提取窗口序列 - fs_period, full_seq = extract_window_sequences(sequence, pos) - - if fs_period is None or full_seq is None: - continue - - # 预测并记录结果 - pred = self.predict_single_position(fs_period, full_seq, short_threshold, ensemble_weight) - pred.update({ - 'Position': pos, - 'Codon': sequence[pos:pos+3], - 'Short_Sequence': fs_period, # 更清晰的命名 - 'Long_Sequence': full_seq # 更清晰的命名 - }) - results.append(pred) - - # 创建结果DataFrame - results_df = pd.DataFrame(results) - - return results_df - - except Exception as e: - raise Exception(f"序列预测过程出错: {str(e)}") - - def plot_sequence_prediction(self, sequence, window_size=3, short_threshold=0.65, - long_threshold=0.8, ensemble_weight=0.4, title=None, save_path=None, - figsize=(12, 6), dpi=300): - """ - 绘制序列预测结果的移码概率图 - - Args: - sequence: 输入DNA序列 - window_size: 滑动窗口大小 (默认为3) - short_threshold: Short模型(HistGB)过滤阈值 (默认为0.65) - long_threshold: Long模型(BiLSTM-CNN)过滤阈值 (默认为0.8) - ensemble_weight: Short模型在集成中的权重 (默认为0.4) - title: 图片标题 (可选) - save_path: 保存路径 (可选,如果提供则保存图片) - figsize: 图片尺寸 (默认为(12, 6)) - dpi: 图片分辨率 (默认为300) - - Returns: - tuple: (pd.DataFrame, matplotlib.figure.Figure) 预测结果和图形对象 - """ - try: - # 验证权重参数 - if not (0.0 <= ensemble_weight <= 1.0): - raise ValueError("ensemble_weight 必须在 0.0 到 1.0 之间") - - long_weight = 1.0 - ensemble_weight - - # 获取预测结果 - 使用新的方法名 - results_df = self.predict_sequence(sequence, window_size=window_size, - short_threshold=0.1, ensemble_weight=ensemble_weight) - - if results_df.empty: - raise ValueError("预测结果为空,请检查输入序列") - - # 获取序列长度 - seq_length = len(sequence) - - # 计算显示宽度 - prob_width = max(1, seq_length // 300) # 概率标记的宽度 - - # 创建图形,包含两个子图 - fig = plt.figure(figsize=figsize) - gs = fig.add_gridspec(2, 1, height_ratios=[0.15, 1], hspace=0.3) - - # 设置标题 - if title: - fig.suptitle(title, y=0.95, fontsize=12) - else: - fig.suptitle(f'序列移码概率预测结果 (权重 {ensemble_weight:.1f}:{long_weight:.1f})', y=0.95, fontsize=12) - - # 预测概率热图 - ax0 = fig.add_subplot(gs[0]) - prob_data = np.zeros((1, seq_length)) - - # 应用双重阈值过滤 - for _, row in results_df.iterrows(): - pos = int(row['Position']) - if (row['Short_Probability'] >= short_threshold and - row['Long_Probability'] >= long_threshold): - # 为每个满足阈值的位置设置概率值 - start = max(0, pos - prob_width//2) - end = min(seq_length, pos + prob_width//2 + 1) - prob_data[0, start:end] = row['Ensemble_Probability'] - - im = ax0.imshow(prob_data, cmap='Reds', aspect='auto', vmin=0, vmax=1, - interpolation='nearest') - ax0.set_xticks([]) - ax0.set_yticks([]) - ax0.set_title(f'预测概率热图 (Short≥{short_threshold}, Long≥{long_threshold})', - pad=5, fontsize=10) - - # 主图(条形图) - ax1 = fig.add_subplot(gs[1]) - - # 应用过滤阈值 - filtered_probs = results_df['Ensemble_Probability'].copy() - mask = ((results_df['Short_Probability'] < short_threshold) | - (results_df['Long_Probability'] < long_threshold)) - filtered_probs[mask] = 0 - - # 绘制条形图 - bars = ax1.bar(results_df['Position'], filtered_probs, - alpha=0.7, color='darkred', width=max(1, window_size)) - - # 设置x轴刻度 - step = max(seq_length // 10, 50) - x_ticks = np.arange(0, seq_length, step) - ax1.set_xticks(x_ticks) - ax1.tick_params(axis='x', rotation=45) - - # 设置标签和标题 - ax1.set_xlabel('序列位置 (bp)', fontsize=10) - ax1.set_ylabel('移码概率', fontsize=10) - ax1.set_title(f'移码概率分布 (集成权重 {ensemble_weight:.1f}:{long_weight:.1f})', fontsize=11) - - # 设置y轴范围 - ax1.set_ylim(0, 1) - - # 添加网格 - ax1.grid(True, alpha=0.3) - - # 添加阈值和权重说明 - info_text = (f'过滤阈值: Short≥{short_threshold}, Long≥{long_threshold}\n' - f'集成权重: Short:{ensemble_weight:.1f}, Long:{long_weight:.1f}') - ax1.text(0.02, 0.95, info_text, transform=ax1.transAxes, - fontsize=9, verticalalignment='top', - bbox=dict(boxstyle='round', facecolor='lightgray', alpha=0.8)) - - # 确保所有子图的x轴范围一致 - for ax in [ax0, ax1]: - ax.set_xlim(-1, seq_length) - - # 调整布局 - plt.tight_layout() - - # 如果提供了保存路径,则保存图片 - if save_path: - plt.savefig(save_path, dpi=dpi, bbox_inches='tight') - # 同时保存PDF版本 - if save_path.endswith('.png'): - pdf_path = save_path.replace('.png', '.pdf') - plt.savefig(pdf_path, bbox_inches='tight') - print(f"图片已保存至: {save_path}") - - return results_df, fig - - except Exception as e: - raise Exception(f"绘制序列预测图时出错: {str(e)}") - - def predict_regions(self, sequences, short_threshold=0.1, ensemble_weight=0.4): - ''' - 预测区域序列(批量预测已知的399bp序列) - - Args: - sequences: 399bp序列或包含399bp序列的DataFrame/Series/list - short_threshold: short模型概率阈值 (默认为0.1) - ensemble_weight: short模型在集成中的权重 (默认为0.4) - - Returns: - DataFrame: 包含所有序列预测概率的DataFrame - ''' - try: - # 验证权重参数 - if not (0.0 <= ensemble_weight <= 1.0): - raise ValueError("ensemble_weight 必须在 0.0 到 1.0 之间") - - # 统一输入格式 - if isinstance(sequences, (pd.DataFrame, pd.Series)): - sequences = sequences.tolist() - elif isinstance(sequences, str): - sequences = [sequences] - - results = [] - for i, seq399 in enumerate(sequences): - try: - # 从399bp序列中截取中心的33bp (short模型使用) - seq33 = self._extract_center_sequence(seq399, target_length=self.short_seq_length) - - # 使用统一的预测方法 - pred_result = self.predict_single_position(seq33, seq399, short_threshold, ensemble_weight) - pred_result.update({ - 'Short_Sequence': seq33, - 'Long_Sequence': seq399 - }) - - results.append(pred_result) - - except Exception as e: - print(f"处理第 {i+1} 个序列时出错: {str(e)}") - long_weight = 1.0 - ensemble_weight - results.append({ - 'Short_Probability': 0.0, - 'Long_Probability': 0.0, - 'Ensemble_Probability': 0.0, - 'Ensemble_Weights': f'Short:{ensemble_weight:.1f}, Long:{long_weight:.1f}', - 'Short_Sequence': self._extract_center_sequence(seq399, target_length=self.short_seq_length) if len(seq399) >= self.short_seq_length else seq399, - 'Long_Sequence': seq399 - }) - - return pd.DataFrame(results) - - except Exception as e: - raise Exception(f"区域预测过程出错: {str(e)}") - - def _extract_center_sequence(self, sequence, target_length=33): - """从序列中心位置提取指定长度的子序列""" - # 确保序列为字符串 - sequence = str(sequence).upper() - - # 如果序列长度小于目标长度,返回原序列 - if len(sequence) <= target_length: - return sequence - - # 计算中心位置 - center = len(sequence) // 2 - half_target = target_length // 2 - - # 提取中心序列 - start = center - half_target - end = start + target_length - - # 边界检查 - if start < 0: - start = 0 - end = target_length - elif end > len(sequence): - end = len(sequence) - start = end - target_length - - return sequence[start:end] - - # 兼容性方法(向后兼容,但标记为废弃) - def predict_full(self, sequence, window_size=3, short_threshold=0.1, short_weight=0.4, plot=False): - """ - ⚠️ 已废弃:请使用 predict_sequence() 方法 - - 向后兼容的方法,内部调用新的 predict_sequence() - """ - import warnings - warnings.warn("predict_full() 已废弃,请使用 predict_sequence() 方法", DeprecationWarning, stacklevel=2) - - # 调用新方法并添加兼容性字段 - results_df = self.predict_sequence(sequence, window_size, short_threshold, short_weight) - - # 添加兼容性字段 - if 'Ensemble_Probability' in results_df.columns: - results_df['Voting_Probability'] = results_df['Ensemble_Probability'] - results_df['Weighted_Probability'] = results_df['Ensemble_Probability'] - if 'Ensemble_Weights' in results_df.columns: - results_df['Weight_Info'] = results_df['Ensemble_Weights'] - if 'Short_Sequence' in results_df.columns: - results_df['33bp'] = results_df['Short_Sequence'] - if 'Long_Sequence' in results_df.columns: - results_df['399bp'] = results_df['Long_Sequence'] - - if plot: - # 如果需要绘图,调用绘图方法 - _, fig = self.plot_sequence_prediction(sequence, window_size, 0.65, 0.8, short_weight) - return results_df, fig - - return results_df - - def predict_region(self, seq, short_threshold=0.1, short_weight=0.4): - """ - ⚠️ 已废弃:请使用 predict_regions() 方法 - - 向后兼容的方法,内部调用新的 predict_regions() - """ - import warnings - warnings.warn("predict_region() 已废弃,请使用 predict_regions() 方法", DeprecationWarning, stacklevel=2) - - # 调用新方法并添加兼容性字段 - results_df = self.predict_regions(seq, short_threshold, short_weight) - - # 添加兼容性字段 - if 'Ensemble_Probability' in results_df.columns: - results_df['Voting_Probability'] = results_df['Ensemble_Probability'] - results_df['Weighted_Probability'] = results_df['Ensemble_Probability'] - if 'Ensemble_Weights' in results_df.columns: - results_df['Weight_Info'] = results_df['Ensemble_Weights'] - if 'Short_Sequence' in results_df.columns: - results_df['33bp'] = results_df['Short_Sequence'] - if 'Long_Sequence' in results_df.columns: - results_df['399bp'] = results_df['Long_Sequence'] - +import os +import pickle +import numpy as np +import pandas as pd +from tensorflow.keras.models import load_model +from .features.sequence import SequenceFeatureExtractor +from .features.cnn_input import CNNInputProcessor +from .utils import extract_window_sequences +import matplotlib.pyplot as plt +import joblib + + +class PRFPredictor: + + def __init__(self, model_dir=None): + """ + 初始化PRF预测器 + + Args: + model_dir: 模型目录路径(可选) + """ + if model_dir is None: + from pkg_resources import resource_filename + model_dir = resource_filename('FScanpy', 'pretrained') + + try: + # 加载模型 - 使用新的命名约定 + self.short_model = self._load_pickle(os.path.join(model_dir, 'short.pkl')) # HistGB模型 + self.long_model = self._load_pickle(os.path.join(model_dir, 'long.pkl')) # BiLSTM-CNN模型 + + # 初始化特征提取器和CNN处理器,使用与训练时相同的序列长度 + self.short_seq_length = 33 # HistGB使用的序列长度 + self.long_seq_length = 399 # BiLSTM-CNN使用的序列长度 + + # 初始化特征提取器和CNN输入处理器 + self.feature_extractor = SequenceFeatureExtractor(seq_length=self.short_seq_length) + self.cnn_processor = CNNInputProcessor(max_length=self.long_seq_length) + + # 检测模型类型以优化预测性能 + self._detect_model_types() + + except FileNotFoundError as e: + raise FileNotFoundError(f"无法找到模型文件: {str(e)}。请确保模型文件 'short.pkl' 和 'long.pkl' 存在于 {model_dir}") + except Exception as e: + raise Exception(f"加载模型出错: {str(e)}") + + def _load_pickle(self, path): + """安全加载pickle文件""" + try: + return joblib.load(path) + except Exception as e: + raise FileNotFoundError(f"无法加载模型文件 {path}: {str(e)}") + + def _detect_model_types(self): + """检测模型类型以优化预测性能""" + self.short_is_sklearn = hasattr(self.short_model, 'predict_proba') + self.long_is_sklearn = hasattr(self.long_model, 'predict_proba') + + def _predict_model(self, model, features, is_sklearn, seq_length): + """统一的模型预测方法""" + try: + if is_sklearn: + # sklearn模型使用特征向量 + if isinstance(features, np.ndarray) and features.ndim > 1: + features = features.flatten() + features_2d = np.array([features]) + pred = model.predict_proba(features_2d) + return pred[0][1] + else: + # 深度学习模型 + if seq_length == self.long_seq_length: + # 对于长序列,使用CNN处理器 + model_input = self.cnn_processor.prepare_sequence(features) + else: + # 对于短序列,转换为数值编码 + base_to_num = {'A': 1, 'T': 2, 'G': 3, 'C': 4, 'N': 0} + seq_numeric = [base_to_num.get(base, 0) for base in features.upper()] + model_input = np.array(seq_numeric).reshape(1, len(seq_numeric), 1) + + # 统一的预测调用 + try: + pred = model.predict(model_input, verbose=0) + except TypeError: + pred = model.predict(model_input) + + # 处理预测结果 + if isinstance(pred, list): + pred = pred[0] + + if hasattr(pred, 'shape') and len(pred.shape) > 1 and pred.shape[1] > 1: + return pred[0][1] + else: + return pred[0][0] if hasattr(pred[0], '__getitem__') else pred[0] + + except Exception as e: + raise Exception(f"模型预测失败: {str(e)}") + + def predict_single_position(self, fs_period, full_seq, short_threshold=0.1, ensemble_weight=0.4): + ''' + 预测单个位置的PRF状态 + + Args: + fs_period: 33bp序列 (short模型使用) + full_seq: 完整序列 (long模型使用) + short_threshold: short模型的概率阈值 (默认为0.1) + ensemble_weight: short模型在集成中的权重 (默认为0.4,long权重为0.6) + Returns: + dict: 包含预测概率的字典 + ''' + try: + # 验证权重参数 + if not (0.0 <= ensemble_weight <= 1.0): + raise ValueError("ensemble_weight 必须在 0.0 到 1.0 之间") + + long_weight = 1.0 - ensemble_weight + + # 处理序列长度 + if len(fs_period) > self.short_seq_length: + fs_period = self.feature_extractor.trim_sequence(fs_period, self.short_seq_length) + + # Short模型预测 (HistGB) + try: + if self.short_is_sklearn: + short_features = self.feature_extractor.extract_features(fs_period) + short_prob = self._predict_model(self.short_model, short_features, True, self.short_seq_length) + else: + short_prob = self._predict_model(self.short_model, fs_period, False, self.short_seq_length) + except Exception as e: + print(f"Short模型预测时出错: {str(e)}") + short_prob = 0.0 + + # 如果short概率低于阈值,则跳过long模型 + if short_prob < short_threshold: + return { + 'Short_Probability': short_prob, + 'Long_Probability': 0.0, + 'Ensemble_Probability': 0.0, + 'Ensemble_Weights': f'Short:{ensemble_weight:.1f}, Long:{long_weight:.1f}' + } + + # Long模型预测 (BiLSTM-CNN) + try: + if self.long_is_sklearn: + long_features = self.feature_extractor.extract_features(full_seq) + long_prob = self._predict_model(self.long_model, long_features, True, self.long_seq_length) + else: + long_prob = self._predict_model(self.long_model, full_seq, False, self.long_seq_length) + except Exception as e: + print(f"Long模型预测时出错: {str(e)}") + long_prob = 0.0 + + # 计算集成概率 + try: + ensemble_prob = ensemble_weight * short_prob + long_weight * long_prob + except Exception as e: + print(f"计算集成概率时出错: {str(e)}") + ensemble_prob = (short_prob + long_prob) / 2 + + return { + 'Short_Probability': short_prob, + 'Long_Probability': long_prob, + 'Ensemble_Probability': ensemble_prob, + 'Ensemble_Weights': f'Short:{ensemble_weight:.1f}, Long:{long_weight:.1f}' + } + + except Exception as e: + raise Exception(f"预测过程出错: {str(e)}") + + def predict_sequence(self, sequence, window_size=3, short_threshold=0.1, ensemble_weight=0.4): + """ + 预测完整序列中的PRF位点(滑动窗口方法) + + Args: + sequence: 输入DNA序列 + window_size: 滑动窗口大小 (默认为3) + short_threshold: short模型概率阈值 (默认为0.1) + ensemble_weight: short模型在集成中的权重 (默认为0.4) + + Returns: + pd.DataFrame: 包含预测结果的DataFrame + """ + if window_size < 1: + raise ValueError("窗口大小必须大于等于1") + if short_threshold < 0: + raise ValueError("short模型阈值必须大于等于0") + if not (0.0 <= ensemble_weight <= 1.0): + raise ValueError("ensemble_weight 必须在 0.0 到 1.0 之间") + + results = [] + long_weight = 1.0 - ensemble_weight + + try: + # 确保序列为字符串并转换为大写 + sequence = str(sequence).upper() + + # 滑动窗口预测 + for pos in range(0, len(sequence) - 2, window_size): + # 提取窗口序列 + fs_period, full_seq = extract_window_sequences(sequence, pos) + + if fs_period is None or full_seq is None: + continue + + # 预测并记录结果 + pred = self.predict_single_position(fs_period, full_seq, short_threshold, ensemble_weight) + pred.update({ + 'Position': pos, + 'Codon': sequence[pos:pos+3], + 'Short_Sequence': fs_period, # 更清晰的命名 + 'Long_Sequence': full_seq # 更清晰的命名 + }) + results.append(pred) + + # 创建结果DataFrame + results_df = pd.DataFrame(results) + + return results_df + + except Exception as e: + raise Exception(f"序列预测过程出错: {str(e)}") + + def plot_sequence_prediction(self, sequence, window_size=3, short_threshold=0.65, + long_threshold=0.8, ensemble_weight=0.4, title=None, save_path=None, + figsize=(12, 8), dpi=300): + """ + Plot sequence PRF prediction results + + Args: + sequence: Input DNA sequence + window_size: Sliding window size (default: 3) + short_threshold: Short model (HistGB) filtering threshold (default: 0.65) + long_threshold: Long model (BiLSTM-CNN) filtering threshold (default: 0.8) + ensemble_weight: Weight of short model in ensemble (default: 0.4) + title: Plot title (optional) + save_path: Save path (optional, saves plot if provided) + figsize: Figure size (default: (12, 8)) + dpi: Figure resolution (default: 300) + + Returns: + tuple: (pd.DataFrame, matplotlib.figure.Figure) prediction results and figure object + """ + try: + # Validate weight parameter + if not (0.0 <= ensemble_weight <= 1.0): + raise ValueError("ensemble_weight must be between 0.0 and 1.0") + + long_weight = 1.0 - ensemble_weight + + # Get prediction results + results_df = self.predict_sequence(sequence, window_size=window_size, + short_threshold=0.1, ensemble_weight=ensemble_weight) + + if results_df.empty: + raise ValueError("Prediction results are empty, please check input sequence") + + # Get sequence length + seq_length = len(sequence) + + # Calculate display width + desired_visual_width = max(3, seq_length // 100) # FS site width ~1% of sequence length + prob_width = max(1, desired_visual_width // 3) # Prediction probability width is 1/3 of FS site width + + # Create figure with three subplots, set height ratios + fig = plt.figure(figsize=figsize) + + # Set title + if title: + fig.suptitle(title, y=0.95, fontsize=10) + else: + fig.suptitle(f'PRF Prediction Results (Weights {ensemble_weight:.1f}:{long_weight:.1f})', y=0.95, fontsize=10) + + # Adjust subplot ratios, make top two heatmaps smaller + gs = fig.add_gridspec(3, 1, height_ratios=[0.1, 0.1, 1], hspace=0.2) + + # FS site heatmap - using fixed width, no blur effect + ax0 = fig.add_subplot(gs[0]) + fs_data = np.zeros((1, seq_length)) + # Note: No actual FS site information in sliding window prediction, so keep empty or show predicted sites + # Show high-confidence predictions as potential FS sites + for _, row in results_df.iterrows(): + pos = int(row['Position']) + if (row['Short_Probability'] >= short_threshold and + row['Long_Probability'] >= long_threshold and + row['Ensemble_Probability'] >= 0.8): # High confidence threshold + half_width = desired_visual_width // 2 + start_pos = max(0, pos - half_width) + end_pos = min(seq_length, pos + half_width + 1) + fs_data[0, start_pos:end_pos] = 1 # Use fixed value, no gradient + + ax0.imshow(fs_data, cmap='Reds', aspect='auto', interpolation='nearest') + ax0.set_xticks([]) + ax0.set_yticks([]) + ax0.set_title('FS site', pad=2, fontsize=8) + + # Prediction probability heatmap - using fixed width to display probabilities + ax1 = fig.add_subplot(gs[1]) + prob_data = np.zeros((1, seq_length)) + + # Apply dual threshold filtering + for _, row in results_df.iterrows(): + pos = int(row['Position']) + if (row['Short_Probability'] >= short_threshold and + row['Long_Probability'] >= long_threshold): + # Set fixed width for each probability value + start = max(0, pos - prob_width//2) + end = min(seq_length, pos + prob_width//2 + 1) + prob_data[0, start:end] = row['Ensemble_Probability'] + + im = ax1.imshow(prob_data, cmap='Reds', aspect='auto', vmin=0, vmax=1, interpolation='nearest') + ax1.set_xticks([]) + ax1.set_yticks([]) + ax1.set_title('Prediction', pad=2, fontsize=8) + + # Main plot (bar chart) + ax2 = fig.add_subplot(gs[2]) + + # Apply filtering thresholds + filtered_probs = results_df['Ensemble_Probability'].copy() + mask = ((results_df['Short_Probability'] < short_threshold) | + (results_df['Long_Probability'] < long_threshold)) + filtered_probs[mask] = 0 + + # Draw bar chart - use black color and alpha=0.6 to match prediction_sample style + ax2.bar(results_df['Position'], filtered_probs, + alpha=0.6, color='black', width=1.0) + + # Set x-axis ticks + step = max(seq_length // 10, 50) + ax2.set_xticks(np.arange(0, seq_length, step)) + ax2.tick_params(axis='x', rotation=45) + + # Set labels + ax2.set_xlabel('Position') + ax2.set_ylabel('Probability') + + # Set y-axis range + ax2.set_ylim(0, 1) + + # Add grid + ax2.grid(True, alpha=0.3) + + # Ensure all subplots have consistent x-axis range + for ax in [ax0, ax1, ax2]: + ax.set_xlim(-1, seq_length) + + # Adjust layout + plt.tight_layout() + + # Save plot if save path is provided + if save_path: + plt.savefig(save_path, dpi=dpi, bbox_inches='tight') + # Also save PDF version + if save_path.endswith('.png'): + pdf_path = save_path.replace('.png', '.pdf') + plt.savefig(pdf_path, bbox_inches='tight') + print(f"Plot saved to: {save_path}") + + return results_df, fig + + except Exception as e: + raise Exception(f"Error plotting sequence prediction: {str(e)}") + + def predict_regions(self, sequences, short_threshold=0.1, ensemble_weight=0.4): + ''' + Predict region sequences (batch prediction of known 399bp sequences) + + Args: + sequences: 399bp sequences or DataFrame/Series/list containing 399bp sequences + short_threshold: Short model probability threshold (default: 0.1) + ensemble_weight: Weight of short model in ensemble (default: 0.4) + + Returns: + DataFrame: DataFrame containing prediction probabilities for all sequences + ''' + try: + # Validate weight parameter + if not (0.0 <= ensemble_weight <= 1.0): + raise ValueError("ensemble_weight must be between 0.0 and 1.0") + + # Unify input format + if isinstance(sequences, (pd.DataFrame, pd.Series)): + sequences = sequences.tolist() + elif isinstance(sequences, str): + sequences = [sequences] + + results = [] + for i, seq399 in enumerate(sequences): + try: + # Extract central 33bp from 399bp sequence (for short model use) + seq33 = self._extract_center_sequence(seq399, target_length=self.short_seq_length) + + # Use unified prediction method + pred_result = self.predict_single_position(seq33, seq399, short_threshold, ensemble_weight) + pred_result.update({ + 'Short_Sequence': seq33, + 'Long_Sequence': seq399 + }) + + results.append(pred_result) + + except Exception as e: + print(f"Error processing sequence {i+1}: {str(e)}") + long_weight = 1.0 - ensemble_weight + results.append({ + 'Short_Probability': 0.0, + 'Long_Probability': 0.0, + 'Ensemble_Probability': 0.0, + 'Ensemble_Weights': f'Short:{ensemble_weight:.1f}, Long:{long_weight:.1f}', + 'Short_Sequence': self._extract_center_sequence(seq399, target_length=self.short_seq_length) if len(seq399) >= self.short_seq_length else seq399, + 'Long_Sequence': seq399 + }) + + return pd.DataFrame(results) + + except Exception as e: + raise Exception(f"Error in region prediction process: {str(e)}") + + def _extract_center_sequence(self, sequence, target_length=33): + """Extract subsequence of specified length from center position of sequence""" + # Ensure sequence is string + sequence = str(sequence).upper() + + # If sequence length is less than target length, return original sequence + if len(sequence) <= target_length: + return sequence + + # Calculate center position + center = len(sequence) // 2 + half_target = target_length // 2 + + # Extract center sequence + start = center - half_target + end = start + target_length + + # Boundary check + if start < 0: + start = 0 + end = target_length + elif end > len(sequence): + end = len(sequence) + start = end - target_length + + return sequence[start:end] + + # 兼容性方法(向后兼容,但标记为废弃) + def predict_full(self, sequence, window_size=3, short_threshold=0.1, short_weight=0.4, plot=False): + """ + ⚠️ 已废弃:请使用 predict_sequence() 方法 + + 向后兼容的方法,内部调用新的 predict_sequence() + """ + import warnings + warnings.warn("predict_full() 已废弃,请使用 predict_sequence() 方法", DeprecationWarning, stacklevel=2) + + # 调用新方法并添加兼容性字段 + results_df = self.predict_sequence(sequence, window_size, short_threshold, short_weight) + + # 添加兼容性字段 + if 'Ensemble_Probability' in results_df.columns: + results_df['Voting_Probability'] = results_df['Ensemble_Probability'] + results_df['Weighted_Probability'] = results_df['Ensemble_Probability'] + if 'Ensemble_Weights' in results_df.columns: + results_df['Weight_Info'] = results_df['Ensemble_Weights'] + if 'Short_Sequence' in results_df.columns: + results_df['33bp'] = results_df['Short_Sequence'] + if 'Long_Sequence' in results_df.columns: + results_df['399bp'] = results_df['Long_Sequence'] + + if plot: + # 如果需要绘图,调用绘图方法 + _, fig = self.plot_sequence_prediction(sequence, window_size, 0.65, 0.8, short_weight) + return results_df, fig + + return results_df + + def predict_region(self, seq, short_threshold=0.1, short_weight=0.4): + """ + ⚠️ 已废弃:请使用 predict_regions() 方法 + + 向后兼容的方法,内部调用新的 predict_regions() + """ + import warnings + warnings.warn("predict_region() 已废弃,请使用 predict_regions() 方法", DeprecationWarning, stacklevel=2) + + # 调用新方法并添加兼容性字段 + results_df = self.predict_regions(seq, short_threshold, short_weight) + + # 添加兼容性字段 + if 'Ensemble_Probability' in results_df.columns: + results_df['Voting_Probability'] = results_df['Ensemble_Probability'] + results_df['Weighted_Probability'] = results_df['Ensemble_Probability'] + if 'Ensemble_Weights' in results_df.columns: + results_df['Weight_Info'] = results_df['Ensemble_Weights'] + if 'Short_Sequence' in results_df.columns: + results_df['33bp'] = results_df['Short_Sequence'] + if 'Long_Sequence' in results_df.columns: + results_df['399bp'] = results_df['Long_Sequence'] + return results_df \ No newline at end of file diff --git a/FScanpy_Demo.ipynb b/FScanpy_Demo.ipynb index 4937334..ff219fa 100644 --- a/FScanpy_Demo.ipynb +++ b/FScanpy_Demo.ipynb @@ -6,65 +6,58 @@ "source": [ "# FScanpy \n", "\n", - "这个 Notebook 展示了如何使用 FScanpy 的真实测试数据进行完整的 PRF 位点预测分析,包括:\n", + "This notebook demonstrates how to use FScanpy with real test data for complete PRF site prediction analysis, including:\n", "\n", - "## 🎯 完整工作流程\n", - "1. **加载测试数据** - 使用内置的真实测试数据\n", - "2. **FScanR 分析** - 从 BLASTX 结果识别潜在 PRF 位点\n", - "3. **序列提取** - 提取 PRF 位点周围的序列\n", - "4. **FScanpy 预测** - 使用机器学习模型预测概率\n", - "5. **结果可视化** - 使用内置绘图函数生成预测结果图表\n", - "6. **序列级预测演示** - 完整序列的滑动窗口分析\n", + "## 🎯 Complete Workflow\n", + "1. **Load Test Data** - Use built-in real test data\n", + "2. **FScanR Analysis** - Identify potential PRF sites from BLASTX results\n", + "3. **Sequence Extraction** - Extract sequences around PRF sites\n", + "4. **FScanpy Prediction** - Use machine learning models to predict probabilities\n", + "5. **Results Visualization** - Generate prediction result plots using built-in plotting functions\n", + "6. **Sequence-level Prediction Demo** - Sliding window analysis of complete sequences\n", "\n", - "## 📊 数据说明\n", - "- **blastx_example.xlsx**: 真实BLASTX比对结果\n", - "- **mrna_example.fasta**: 真实mRNA序列数据\n", - "- **region_example.csv**: 单独对某个位点进行预测的样本" + "## 📊 Data Description\n", + "- **blastx_example.xlsx**: Real BLASTX alignment results\n", + "- **mrna_example.fasta**: Real mRNA sequence data\n", + "- **region_example.csv**: Sample for individual site prediction" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## 📦 环境准备和数据加载" + "## 📦 Environment Setup and Data Loading" ] }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 3, "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "✅ 环境准备完成!\n", - "📋 可用的测试数据:\n" + "ename": "ImportError", + "evalue": "cannot import name 'PRFPredictor' from 'FScanpy' (unknown location)", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mImportError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[3], line 6\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mmatplotlib\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mpyplot\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mplt\u001b[39;00m\n\u001b[1;32m 5\u001b[0m \u001b[38;5;66;03m# Import FScanpy related modules\u001b[39;00m\n\u001b[0;32m----> 6\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mFScanpy\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m PRFPredictor, predict_prf, plot_prf_prediction\n\u001b[1;32m 7\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mFScanpy\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mdata\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m get_test_data_path, list_test_data\n\u001b[1;32m 8\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mFScanpy\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mutils\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m fscanr, extract_prf_regions\n", + "\u001b[0;31mImportError\u001b[0m: cannot import name 'PRFPredictor' from 'FScanpy' (unknown location)" ] - }, - { - "data": { - "text/plain": [ - "['blastx_example.xlsx', 'mrna_example.fasta', 'region_example.csv']" - ] - }, - "execution_count": 22, - "metadata": {}, - "output_type": "execute_result" } ], "source": [ - "# 导入必要的库\n", + "# Import necessary libraries\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "\n", - "# 导入FScanpy相关模块\n", + "# Import FScanpy related modules\n", "from FScanpy import PRFPredictor, predict_prf, plot_prf_prediction\n", "from FScanpy.data import get_test_data_path, list_test_data\n", "from FScanpy.utils import fscanr, extract_prf_regions\n", "\n", - "print(\"✅ 环境准备完成!\")\n", - "print(\"📋 可用的测试数据:\")\n", + "print(\"✅ Environment setup complete!\")\n", + "print(\"📋 Available test data:\")\n", "list_test_data()" ] }, @@ -72,9 +65,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## 1. 加载和探索测试数据\n", + "## 1. Load and Explore Test Data\n", "\n", - "首先加载 FScanpy 提供的真实测试数据,了解数据结构。" + "First, load the real test data provided by FScanpy to understand the data structure." ] }, { @@ -107,25 +100,25 @@ } ], "source": [ - "# 获取测试数据路径\n", + "# Get test data paths\n", "blastx_file = get_test_data_path('blastx_example.xlsx')\n", "mrna_file = get_test_data_path('mrna_example.fasta')\n", "region_file = get_test_data_path('region_example.csv')\n", "\n", - "print(f\"📁 数据文件路径:\")\n", - "print(f\" BLASTX数据: {blastx_file}\")\n", - "print(f\" mRNA序列: {mrna_file}\")\n", - "print(f\" 验证区域: {region_file}\")\n", + "print(f\"📁 Data file paths:\")\n", + "print(f\" BLASTX data: {blastx_file}\")\n", + "print(f\" mRNA sequences: {mrna_file}\")\n", + "print(f\" Validation regions: {region_file}\")\n", "\n", - "# 加载BLASTX数据\n", + "# Load BLASTX data\n", "blastx_data = pd.read_excel(blastx_file)\n", - "print(f\"\\n🧬 BLASTX数据概览:\")\n", - "print(f\" 数据形状: {blastx_data.shape}\")\n", - "print(f\" 列名: {list(blastx_data.columns)}\")\n", - "print(f\" 唯一序列数: {blastx_data['DNA_seqid'].nunique()}\")\n", + "print(f\"\\n🧬 BLASTX data overview:\")\n", + "print(f\" Data shape: {blastx_data.shape}\")\n", + "print(f\" Column names: {list(blastx_data.columns)}\")\n", + "print(f\" Unique sequences: {blastx_data['DNA_seqid'].nunique()}\")\n", "\n", - "# 显示前几行\n", - "print(\"\\n📊 BLASTX数据示例:\")\n", + "# Display first few rows\n", + "print(\"\\n📊 BLASTX data examples:\")\n", "display_cols = ['DNA_seqid', 'Pep_seqid', 'pident', 'length', 'evalue', 'qframe']\n", "print(blastx_data[display_cols].head())" ] @@ -163,21 +156,21 @@ } ], "source": [ - "# 加载验证区域数据\n", + "# Load validation region data\n", "region_data = pd.read_csv(region_file)\n", - "print(f\"🎯 验证区域数据概览:\")\n", - "print(f\" 数据形状: {region_data.shape}\")\n", - "print(f\" 列名: {list(region_data.columns)}\")\n", - "print(f\" 数据来源: {region_data['source'].value_counts().to_dict()}\")\n", + "print(f\"🎯 Validation region data overview:\")\n", + "print(f\" Data shape: {region_data.shape}\")\n", + "print(f\" Column names: {list(region_data.columns)}\")\n", + "print(f\" Data sources: {region_data['source'].value_counts().to_dict()}\")\n", "\n", - "print(\"\\n📋 验证区域数据示例:\")\n", + "print(\"\\n📋 Validation region data examples:\")\n", "display_cols = ['fs_position', 'DNA_seqid', 'label', 'source', 'FS_type']\n", "print(region_data[display_cols].head())\n", "\n", - "# 统计分析\n", - "print(f\"\\n📈 标签分布:\")\n", + "# Statistical analysis\n", + "print(f\"\\n📈 Label distribution:\")\n", "print(region_data['label'].value_counts())\n", - "print(f\"\\n🔬 FS类型分布:\")\n", + "print(f\"\\n🔬 FS type distribution:\")\n", "print(region_data['FS_type'].value_counts())" ] }, @@ -185,9 +178,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## 2. FScanR 分析 - 从 BLASTX 识别潜在 PRF 位点\n", + "## 2. FScanR Analysis - Identify Potential PRF Sites from BLASTX\n", "\n", - "使用 FScanR 算法分析 BLASTX 结果,识别潜在的程序性核糖体移码位点。" + "Use the FScanR algorithm to analyze BLASTX results and identify potential programmed ribosomal frameshift sites." ] }, { @@ -229,9 +222,9 @@ } ], "source": [ - "# 运行FScanR分析\n", - "print(\"🔍 运行FScanR分析...\")\n", - "print(\"参数设置: mismatch_cutoff=10, evalue_cutoff=1e-5, frameDist_cutoff=10\")\n", + "# Run FScanR analysis\n", + "print(\"🔍 Running FScanR analysis...\")\n", + "print(\"Parameter settings: mismatch_cutoff=10, evalue_cutoff=1e-5, frameDist_cutoff=100\")\n", "\n", "fscanr_results = fscanr(\n", " blastx_data,\n", @@ -240,29 +233,29 @@ " frameDist_cutoff=100\n", ")\n", "\n", - "print(f\"\\n✅ FScanR分析完成!\")\n", - "print(f\"检测到的潜在PRF位点数量: {len(fscanr_results)}\")\n", + "print(f\"\\n✅ FScanR analysis complete!\")\n", + "print(f\"Number of potential PRF sites detected: {len(fscanr_results)}\")\n", "\n", "if len(fscanr_results) > 0:\n", - " print(f\"\\n📊 FScanR结果概览:\")\n", - " print(f\" 列名: {list(fscanr_results.columns)}\")\n", - " print(f\" 涉及的序列数: {fscanr_results['DNA_seqid'].nunique()}\")\n", - " print(f\" 链方向分布: {fscanr_results['Strand'].value_counts().to_dict()}\")\n", - " print(f\" FS类型分布: {fscanr_results['FS_type'].value_counts().to_dict()}\")\n", + " print(f\"\\n📊 FScanR results overview:\")\n", + " print(f\" Column names: {list(fscanr_results.columns)}\")\n", + " print(f\" Number of sequences involved: {fscanr_results['DNA_seqid'].nunique()}\")\n", + " print(f\" Strand orientation distribution: {fscanr_results['Strand'].value_counts().to_dict()}\")\n", + " print(f\" FS type distribution: {fscanr_results['FS_type'].value_counts().to_dict()}\")\n", " \n", - " print(\"\\n🎯 FScanR结果示例:\")\n", + " print(\"\\n🎯 FScanR results examples:\")\n", " print(fscanr_results.head())\n", "else:\n", - " print(\"⚠️ 未检测到PRF位点,可能需要调整参数\")" + " print(\"⚠️ No PRF sites detected, may need to adjust parameters\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## 3. 序列提取 - 获取 PRF 位点周围序列\n", + "## 3. Sequence Extraction - Extract Sequences Around PRF Sites\n", "\n", - "从 mRNA 序列中提取 FScanR 识别的 PRF 位点周围的序列片段。" + "Extract sequence fragments around PRF sites identified by FScanR from mRNA sequences." ] }, { @@ -306,36 +299,36 @@ } ], "source": [ - "# 提取PRF位点周围的序列\n", + "# Extract sequences around PRF sites\n", "if len(fscanr_results) > 0:\n", - " print(\"📝 从mRNA序列中提取PRF位点周围序列...\")\n", + " print(\"📝 Extracting sequences around PRF sites from mRNA sequences...\")\n", " \n", " prf_sequences = extract_prf_regions(\n", " mrna_file=mrna_file,\n", " prf_data=fscanr_results\n", " )\n", " \n", - " print(f\"\\n✅ 序列提取完成!\")\n", - " print(f\"成功提取的序列数量: {len(prf_sequences)}\")\n", + " print(f\"\\n✅ Sequence extraction complete!\")\n", + " print(f\"Number of successfully extracted sequences: {len(prf_sequences)}\")\n", " \n", " if len(prf_sequences) > 0:\n", - " print(f\"\\n📏 序列长度验证:\")\n", + " print(f\"\\n📏 Sequence length validation:\")\n", " seq_lengths = prf_sequences['399bp'].str.len()\n", - " print(f\" 399bp序列长度分布: {seq_lengths.value_counts().to_dict()}\")\n", - " print(f\" 平均长度: {seq_lengths.mean():.1f}\")\n", + " print(f\" 399bp sequence length distribution: {seq_lengths.value_counts().to_dict()}\")\n", + " print(f\" Average length: {seq_lengths.mean():.1f}\")\n", " \n", - " print(\"\\n🧬 提取序列示例:\")\n", + " print(\"\\n🧬 Extracted sequence examples:\")\n", " for i, row in prf_sequences.head(3).iterrows():\n", - " print(f\"序列 {i+1}: {row['DNA_seqid']}\")\n", - " print(f\" FS位置: {row['FS_start']}-{row['FS_end']}\")\n", - " print(f\" 链方向: {row['Strand']}\")\n", - " print(f\" FS类型: {row['FS_type']}\")\n", - " print(f\" 序列片段: {row['399bp'][:50]}...{row['399bp'][-20:]}\")\n", + " print(f\"Sequence {i+1}: {row['DNA_seqid']}\")\n", + " print(f\" FS position: {row['FS_start']}-{row['FS_end']}\")\n", + " print(f\" Strand orientation: {row['Strand']}\")\n", + " print(f\" FS type: {row['FS_type']}\")\n", + " print(f\" Sequence fragment: {row['399bp'][:50]}...{row['399bp'][-20:]}\")\n", " print()\n", " else:\n", - " print(\"❌ 序列提取失败\")\n", + " print(\"❌ Sequence extraction failed\")\n", "else:\n", - " print(\"⚠️ 跳过序列提取 - 无FScanR结果\")\n", + " print(\"⚠️ Skipping sequence extraction - no FScanR results\")\n", " prf_sequences = pd.DataFrame()" ] }, @@ -343,9 +336,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## 4. FScanpy 预测 - 机器学习模型分析\n", + "## 4. FScanpy Prediction - Machine Learning Model Analysis\n", "\n", - "使用 FScanpy 的机器学习模型对提取的序列进行 PRF 概率预测。" + "Use FScanpy's machine learning models to predict PRF probabilities for the extracted sequences." ] }, { @@ -379,26 +372,26 @@ } ], "source": [ - "# 初始化预测器\n", + "# Initialize predictor\n", "predictor = PRFPredictor()\n", - "print(\"🤖 FScanpy预测器初始化完成\")\n", + "print(\"🤖 FScanpy predictor initialization complete\")\n", "\n", - "# 对FScanR识别的序列进行预测\n", + "# Predict FScanR identified sequences\n", "if len(prf_sequences) > 0:\n", - " print(f\"\\n🎯 对 {len(prf_sequences)} 个FScanR识别的序列进行预测...\")\n", + " print(f\"\\n🎯 Predicting {len(prf_sequences)} sequences identified by FScanR...\")\n", " \n", " fscanr_predictions = predictor.predict_regions(\n", " sequences=prf_sequences['399bp'],\n", - " ensemble_weight=0.4 # 平衡配置\n", + " ensemble_weight=0.4 # Balanced configuration\n", " )\n", " \n", - " # 合并结果\n", + " # Merge results\n", " fscanr_predictions = pd.concat([\n", " prf_sequences.reset_index(drop=True),\n", " fscanr_predictions.reset_index(drop=True)\n", " ], axis=1)\n", " \n", - " print(\"\\n📊 FScanR+FScanpy预测结果:\")\n", + " print(\"\\n📊 FScanR+FScanpy prediction results:\")\n", " result_cols = ['DNA_seqid', 'FS_start', 'FS_type', 'Short_Probability', 'Long_Probability', 'Ensemble_Probability']\n", " print(fscanr_predictions[result_cols].head())" ] @@ -429,15 +422,15 @@ } ], "source": [ - "# 对验证区域数据进行预测\n", - "print(f\"\\n🧪 对 {len(region_data)} 个验证区域进行预测...\")\n", + "# Predict validation region data\n", + "print(f\"\\n🧪 Predicting {len(region_data)} validation regions...\")\n", "\n", "validation_predictions = predict_prf(\n", " data=region_data.rename(columns={'399bp': 'Long_Sequence'}),\n", " ensemble_weight=0.4\n", ")\n", "\n", - "print(\"\\n📊 验证区域预测结果:\")\n", + "print(\"\\n📊 Validation region prediction results:\")\n", "result_cols = ['DNA_seqid', 'label', 'source', 'Short_Probability', 'Long_Probability', 'Ensemble_Probability']\n", "print(validation_predictions[result_cols].head())" ] @@ -446,9 +439,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## 5. 序列级预测和可视化\n", + "## 5. Sequence-level Prediction and Visualization\n", "\n", - "选择一个具体的mRNA序列,使用内置的plot_prf_prediction函数进行完整的滑动窗口预测和可视化。" + "Select a specific mRNA sequence and use the built-in plot_prf_prediction function for complete sliding window prediction and visualization." ] }, { @@ -557,19 +550,19 @@ } ], "source": [ - "# 选择一个序列进行演示\n", + "# Select a sequence for demonstration\n", "from Bio import SeqIO\n", "\n", - "# 读取第一个mRNA序列作为演示\n", + "# Read the first mRNA sequence for demonstration\n", "mrna_sequences = list(SeqIO.parse(mrna_file, \"fasta\"))\n", - "demo_seq = mrna_sequences[0] # 选择第一个序列\n", + "demo_seq = mrna_sequences[0] # Select the first sequence\n", "\n", - "print(f\"🧬 选择演示序列: {demo_seq.id}\")\n", - "print(f\"序列长度: {len(demo_seq.seq)} bp\")\n", - "print(f\"序列前100bp: {str(demo_seq.seq)[:100]}...\")\n", + "print(f\"🧬 Selected demonstration sequence: {demo_seq.id}\")\n", + "print(f\"Sequence length: {len(demo_seq.seq)} bp\")\n", + "print(f\"First 100bp of sequence: {str(demo_seq.seq)[:100]}...\")\n", "\n", - "# 使用内置的plot_prf_prediction函数进行预测和可视化\n", - "print(f\"\\n🎯 使用plot_prf_prediction进行序列预测和可视化...\")\n", + "# Use built-in plot_prf_prediction function for prediction and visualization\n", + "print(f\"\\n🎯 Using plot_prf_prediction for sequence prediction and visualization...\")\n", "\n", "sequence_results, fig = plot_prf_prediction(\n", " sequence=str(demo_seq.seq),\n", @@ -577,18 +570,18 @@ " short_threshold=0.2,\n", " long_threshold=0.2,\n", " ensemble_weight=0.6,\n", - " title=f\"序列 {demo_seq.id} 的PRF预测结果(条形图+热图)\",\n", + " title=f\"PRF Prediction Results for Sequence {demo_seq.id} (Bar Chart + Heatmap)\",\n", " figsize=(16, 8),\n", " dpi=150\n", ")\n", "\n", "plt.show()\n", "\n", - "print(f\"\\n📊 序列预测结果统计:\")\n", - "print(f\" 预测位点总数: {len(sequence_results)}\")\n", - "print(f\" 高概率位点 (>0.8): {(sequence_results['Ensemble_Probability'] > 0.8).sum()}\")\n", - "print(f\" 中概率位点 (0.4-0.8): {((sequence_results['Ensemble_Probability'] >= 0.4) & (sequence_results['Ensemble_Probability'] <= 0.8)).sum()}\")\n", - "print(f\" 最高预测概率: {sequence_results['Ensemble_Probability'].max():.3f}\")" + "print(f\"\\n📊 Sequence prediction result statistics:\")\n", + "print(f\" Total predicted sites: {len(sequence_results)}\")\n", + "print(f\" High probability sites (>0.8): {(sequence_results['Ensemble_Probability'] > 0.8).sum()}\")\n", + "print(f\" Medium probability sites (0.4-0.8): {((sequence_results['Ensemble_Probability'] >= 0.4) & (sequence_results['Ensemble_Probability'] <= 0.8)).sum()}\")\n", + "print(f\" Highest prediction probability: {sequence_results['Ensemble_Probability'].max():.3f}\")" ] }, { @@ -634,53 +627,53 @@ } ], "source": [ - "# 打印Top预测位点的概率\n", + "# Print top predicted site probabilities\n", "if sequence_results['Ensemble_Probability'].max() > 0.3:\n", " top_predictions = sequence_results.nlargest(5, 'Ensemble_Probability')\n", - " print(f\"\\n🔝 Top 5 预测位点:\")\n", + " print(f\"\\n🔝 Top 5 predicted sites:\")\n", " for i, (_, row) in enumerate(top_predictions.iterrows(), 1):\n", - " print(f\" {i}. 位置 {row['Position']}: \")\n", - " print(f\" - Short概率: {row['Short_Probability']:.3f}\")\n", - " print(f\" - Long概率: {row['Long_Probability']:.3f}\")\n", - " print(f\" - 集成概率: {row['Ensemble_Probability']:.3f}\")\n", - " print(f\" - 密码子: {row['Codon']}\")\n", + " print(f\" {i}. Position {row['Position']}: \")\n", + " print(f\" - Short probability: {row['Short_Probability']:.3f}\")\n", + " print(f\" - Long probability: {row['Long_Probability']:.3f}\")\n", + " print(f\" - Ensemble probability: {row['Ensemble_Probability']:.3f}\")\n", + " print(f\" - Codon: {row['Codon']}\")\n", "else:\n", - " print(\"\\n💡 该序列没有检测到高概率的PRF位点\")\n", + " print(\"\\n💡 No high-probability PRF sites detected in this sequence\")\n", "\n", - "print(\"\\n📊 可视化分析完成!\")\n", - "print(\"图表包含热图和条形图,展示了整个序列的PRF预测概率分布。\")" + "print(\"\\n📊 Visualization analysis complete!\")\n", + "print(\"The chart contains heatmaps and bar charts showing the PRF prediction probability distribution across the entire sequence.\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## 📝 分析总结\n", + "## 📝 Analysis Summary\n", "\n", - "### 🎯 主要发现\n", - "1. **数据质量**: 测试数据集包含真实的BLASTX比对结果和验证区域\n", - "2. **FScanR效果**: 从BLASTX结果中识别出潜在PRF位点\n", - "3. **模型性能**: Short和Long模型在不同场景下各有优势\n", - "4. **预测结果**: 集成模型提供了更稳定的预测性能\n", - "5. **可视化**: 内置绘图函数生成清晰的热图和条形图\n", + "### 🎯 Key Findings\n", + "1. **Data Quality**: Test dataset contains real BLASTX alignment results and validation regions\n", + "2. **FScanR Performance**: Successfully identified potential PRF sites from BLASTX results\n", + "3. **Model Performance**: Short and Long models each have advantages in different scenarios\n", + "4. **Prediction Results**: Ensemble model provides more stable prediction performance\n", + "5. **Visualization**: Built-in plotting functions generate clear heatmaps and bar charts\n", "\n", - "### 🔧 最佳实践\n", - "- **数据预处理**: 确保BLASTX结果格式正确\n", - "- **参数设置**: 使用默认的集成权重(0.4:0.6)获得平衡性能\n", - "- **结果解读**: 在使用FScanpy对整条序列进行预测时,不应该使用0.5作为阈值,而应该比较不同位置的概率高低\n", - "- **可视化**: 使用plot_prf_prediction函数生成标准化图表\n", + "### 🔧 Best Practices\n", + "- **Data Preprocessing**: Ensure BLASTX results are in correct format\n", + "- **Parameter Settings**: Use default ensemble weights (0.4:0.6) for balanced performance\n", + "- **Result Interpretation**: When using FScanpy for whole sequence prediction, don't use 0.5 as threshold, but compare relative probabilities across positions\n", + "- **Visualization**: Use plot_prf_prediction function to generate standardized plots\n", "\n", - "### 📚 使用建议\n", - "1. **阈值选择**: 根据应用场景调整概率阈值\n", - "2. **结果验证**: 结合生物学知识验证预测结果\n", - "3. **性能优化**: 对于大规模数据使用合理的滑动窗口大小\n", - "4. **可视化参数**: 调整figsize和dpi获得最佳显示效果" + "### 📚 Usage Recommendations\n", + "1. **Threshold Selection**: Adjust probability thresholds based on application scenarios\n", + "2. **Result Validation**: Validate prediction results with biological knowledge\n", + "3. **Performance Optimization**: Use reasonable sliding window sizes for large-scale data\n", + "4. **Visualization Parameters**: Adjust figsize and dpi for optimal display" ] } ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "tf200", "language": "python", "name": "python3" }, @@ -694,7 +687,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.5" + "version": "3.9.0" } }, "nbformat": 4,