"""PRF predictor that ensembles a short-window model and a long-window model."""

import os
import pickle
import numpy as np
import pandas as pd
from tensorflow.keras.models import load_model
from .features.sequence import SequenceFeatureExtractor
from .features.cnn_input import CNNInputProcessor
from .utils import extract_window_sequences
import matplotlib.pyplot as plt
import joblib


class PRFPredictor:
    """Ensemble PRF predictor built from a short-window and a long-window model.

    The short model scores a 33-base window and the long model a 399-base
    window; the two probabilities are blended with a caller-supplied weight.
    """

    def __init__(self, model_dir=None):
        """Load both models and create the feature / input processors.

        Args:
            model_dir: Directory holding ``short.pkl`` and ``long.pkl``.
                When ``None``, the ``pretrained`` directory shipped with the
                ``FScanpy`` package is used.

        Raises:
            FileNotFoundError: If a model file cannot be loaded.
            Exception: For any other failure while initialising.
        """
        if model_dir is None:
            model_dir = self._default_model_dir()

        try:
            # Models, following the short/long naming convention.
            self.short_model = self._load_pickle(os.path.join(model_dir, 'short.pkl'))  # HistGB model
            self.long_model = self._load_pickle(os.path.join(model_dir, 'long.pkl'))  # BiLSTM-CNN model

            # Window lengths must match the ones used at training time.
            self.short_seq_length = 33
            self.long_seq_length = 399

            self.feature_extractor = SequenceFeatureExtractor(seq_length=self.short_seq_length)
            self.cnn_processor = CNNInputProcessor(max_length=self.long_seq_length)

            # Record which prediction API each model exposes.
            self._detect_model_types()
        except FileNotFoundError as e:
            raise FileNotFoundError(f"无法找到模型文件: {str(e)}。请确保模型文件 'short.pkl' 和 'long.pkl' 存在于 {model_dir}") from e
        except Exception as e:
            raise Exception(f"加载模型出错: {str(e)}") from e

    @staticmethod
    def _default_model_dir():
        """Return the path of the ``pretrained`` directory bundled with FScanpy.

        ``pkg_resources`` is tried first (the original behaviour); when it is
        not installed, ``importlib.resources`` is used instead.
        """
        try:
            from pkg_resources import resource_filename
        except ImportError:
            from importlib.resources import files
            return str(files('FScanpy').joinpath('pretrained'))
        return resource_filename('FScanpy', 'pretrained')

    def _load_pickle(self, path):
        """Load a serialized model from ``path`` with joblib.

        SECURITY: ``joblib.load`` unpickles arbitrary objects and can execute
        code embedded in the file; only load model files from trusted sources.

        Raises:
            FileNotFoundError: On any load failure (type kept for backward
                compatibility); the underlying error is chained as the cause.
        """
        try:
            return joblib.load(path)
        except Exception as e:
            raise FileNotFoundError(f"无法加载模型文件 {path}: {str(e)}") from e

    def _detect_model_types(self):
        """Flag each model as sklearn-style when it exposes ``predict_proba``."""
        self.short_is_sklearn = hasattr(self.short_model, 'predict_proba')
        self.long_is_sklearn = hasattr(self.long_model, 'predict_proba')

    @staticmethod
    def _positive_class_score(pred):
        """Pull a single score out of a raw ``model.predict`` result.

        A list result is unwrapped to its first element. When the result has
        more than one column, element ``[0][1]`` is returned; otherwise the
        first element is returned, unwrapped once more only if it is itself
        indexable along an axis.
        """
        if isinstance(pred, list):
            pred = pred[0]
        if hasattr(pred, 'shape') and len(pred.shape) > 1 and pred.shape[1] > 1:
            return pred[0][1]
        first = pred[0]
        # NumPy scalars expose __getitem__ but reject integer indices, so the
        # dimensionality is checked instead of the attribute.
        if np.ndim(first) > 0:
            return first[0]
        return first

    def _predict_model(self, model, features, is_sklearn, seq_length):
        """Run one model on one sample and return its score.

        Args:
            model: The loaded model object.
            features: A feature vector when ``is_sklearn`` is true, otherwise
                the raw sequence string.
            is_sklearn: Whether to call ``predict_proba`` (true) or
                ``predict`` (false).
            seq_length: Window length of the model; selects the input
                preparation used for non-sklearn models.

        Returns:
            The score taken from the model output.

        Raises:
            Exception: If input preparation or prediction fails.
        """
        try:
            if is_sklearn:
                # sklearn-style models take a single flat feature row.
                if isinstance(features, np.ndarray) and features.ndim > 1:
                    features = features.flatten()
                features_2d = np.array([features])
                pred = model.predict_proba(features_2d)
                return pred[0][1]

            if seq_length == self.long_seq_length:
                # Long window: delegate input preparation to the CNN processor.
                model_input = self.cnn_processor.prepare_sequence(features)
            else:
                # Short window: integer-encode the bases, unknown bases map to 0.
                base_to_num = {'A': 1, 'T': 2, 'G': 3, 'C': 4, 'N': 0}
                seq_numeric = [base_to_num.get(base, 0) for base in features.upper()]
                model_input = np.array(seq_numeric).reshape(1, len(seq_numeric), 1)

            try:
                pred = model.predict(model_input, verbose=0)
            except TypeError:
                # Retry without the ``verbose`` keyword.
                pred = model.predict(model_input)

            return self._positive_class_score(pred)
        except Exception as e:
            raise Exception(f"模型预测失败: {str(e)}") from e

    def predict_single_position(self, fs_period, full_seq, short_threshold=0.1, ensemble_weight=0.4):
        """Predict the PRF probabilities for a single position.

        Args:
            fs_period: Short window (33 bp) used by the short model; longer
                input is trimmed by the feature extractor.
            full_seq: Long window used by the long model.
            short_threshold: If the short model's probability is below this
                value, the long model is skipped (default 0.1).
            ensemble_weight: Weight of the short model in the ensemble; the
                long model gets ``1 - ensemble_weight`` (default 0.4).

        Returns:
            dict: ``Short_Probability``, ``Long_Probability``,
            ``Ensemble_Probability`` and ``Ensemble_Weights``. When the long
            model is skipped, the long and ensemble probabilities are 0.0.

        Raises:
            Exception: If the weight is out of range or processing fails.
        """
        try:
            if not (0.0 <= ensemble_weight <= 1.0):
                raise ValueError("ensemble_weight 必须在 0.0 到 1.0 之间")

            long_weight = 1.0 - ensemble_weight
            weights_label = f'Short:{ensemble_weight:.1f}, Long:{long_weight:.1f}'

            if len(fs_period) > self.short_seq_length:
                fs_period = self.feature_extractor.trim_sequence(fs_period, self.short_seq_length)

            # Short model; a failure is reported and scored as 0.0 (best effort).
            try:
                if self.short_is_sklearn:
                    short_features = self.feature_extractor.extract_features(fs_period)
                    short_prob = self._predict_model(self.short_model, short_features, True, self.short_seq_length)
                else:
                    short_prob = self._predict_model(self.short_model, fs_period, False, self.short_seq_length)
            except Exception as e:
                print(f"Short模型预测时出错: {str(e)}")
                short_prob = 0.0

            # Skip the long model when the short model is below the threshold.
            if short_prob < short_threshold:
                return {
                    'Short_Probability': short_prob,
                    'Long_Probability': 0.0,
                    'Ensemble_Probability': 0.0,
                    'Ensemble_Weights': weights_label
                }

            # Long model; a failure is reported and scored as 0.0 (best effort).
            try:
                if self.long_is_sklearn:
                    long_features = self.feature_extractor.extract_features(full_seq)
                    long_prob = self._predict_model(self.long_model, long_features, True, self.long_seq_length)
                else:
                    long_prob = self._predict_model(self.long_model, full_seq, False, self.long_seq_length)
            except Exception as e:
                print(f"Long模型预测时出错: {str(e)}")
                long_prob = 0.0

            # Weighted blend, with a plain average as the fallback.
            try:
                ensemble_prob = ensemble_weight * short_prob + long_weight * long_prob
            except Exception as e:
                print(f"计算集成概率时出错: {str(e)}")
                ensemble_prob = (short_prob + long_prob) / 2

            return {
                'Short_Probability': short_prob,
                'Long_Probability': long_prob,
                'Ensemble_Probability': ensemble_prob,
                'Ensemble_Weights': weights_label
            }
        except Exception as e:
            raise Exception(f"预测过程出错: {str(e)}") from e

    def predict_sequence(self, sequence, window_size=3, short_threshold=0.1, ensemble_weight=0.4):
        """Scan a whole sequence for PRF sites with a sliding window.

        Args:
            sequence: Input DNA sequence; converted to an upper-case string.
            window_size: Step between scanned positions (default 3).
            short_threshold: Short model probability threshold (default 0.1).
            ensemble_weight: Weight of the short model in the ensemble
                (default 0.4).

        Returns:
            pd.DataFrame: One row per scanned position with the prediction
            fields plus ``Position``, ``Codon``, ``Short_Sequence`` and
            ``Long_Sequence``. Positions for which no window could be
            extracted are omitted.

        Raises:
            ValueError: If an argument is out of range.
            Exception: If prediction fails.
        """
        if window_size < 1:
            raise ValueError("窗口大小必须大于等于1")
        if short_threshold < 0:
            raise ValueError("short模型阈值必须大于等于0")
        if not (0.0 <= ensemble_weight <= 1.0):
            raise ValueError("ensemble_weight 必须在 0.0 到 1.0 之间")

        results = []

        try:
            sequence = str(sequence).upper()

            for pos in range(0, len(sequence) - 2, window_size):
                fs_period, full_seq = extract_window_sequences(sequence, pos)

                if fs_period is None or full_seq is None:
                    continue

                pred = self.predict_single_position(fs_period, full_seq, short_threshold, ensemble_weight)
                pred.update({
                    'Position': pos,
                    'Codon': sequence[pos:pos + 3],
                    'Short_Sequence': fs_period,
                    'Long_Sequence': full_seq
                })
                results.append(pred)

            return pd.DataFrame(results)
        except Exception as e:
            raise Exception(f"序列预测过程出错: {str(e)}") from e

    def plot_sequence_prediction(self, sequence, window_size=3, short_threshold=0.65,
                                 long_threshold=0.8, ensemble_weight=0.4, title=None,
                                 save_path=None, figsize=(12, 8), dpi=300):
        """Plot the sliding-window PRF prediction results for a sequence.

        Predictions are always computed with a short-model threshold of 0.1;
        ``short_threshold`` and ``long_threshold`` only filter what is drawn.

        Args:
            sequence: Input DNA sequence.
            window_size: Sliding window step (default 3).
            short_threshold: Short model (HistGB) display threshold
                (default 0.65).
            long_threshold: Long model (BiLSTM-CNN) display threshold
                (default 0.8).
            ensemble_weight: Weight of the short model in the ensemble
                (default 0.4).
            title: Plot title (optional).
            save_path: Output path (optional). A path ending in ``.png`` also
                gets a ``.pdf`` copy next to it.
            figsize: Figure size (default (12, 8)).
            dpi: Resolution used when saving (default 300).

        Returns:
            tuple: ``(pd.DataFrame, matplotlib.figure.Figure)`` with the
            prediction results and the figure.

        Raises:
            Exception: If validation, prediction or plotting fails; a figure
                created before the failure is closed.
        """
        fig = None
        try:
            if not (0.0 <= ensemble_weight <= 1.0):
                raise ValueError("ensemble_weight must be between 0.0 and 1.0")

            long_weight = 1.0 - ensemble_weight

            results_df = self.predict_sequence(sequence, window_size=window_size,
                                               short_threshold=0.1,
                                               ensemble_weight=ensemble_weight)

            if results_df.empty:
                raise ValueError("Prediction results are empty, please check input sequence")

            seq_length = len(sequence)

            # Marker widths: FS sites ~1% of the sequence, probabilities a third of that.
            desired_visual_width = max(3, seq_length // 100)
            prob_width = max(1, desired_visual_width // 3)
            fs_half_width = desired_visual_width // 2
            prob_half_width = prob_width // 2

            fig = plt.figure(figsize=figsize)

            if title:
                fig.suptitle(title, y=0.95, fontsize=10)
            else:
                fig.suptitle(f'PRF Prediction Results (Weights {ensemble_weight:.1f}:{long_weight:.1f})',
                             y=0.95, fontsize=10)

            # Two thin heatmap strips above the main bar chart.
            gs = fig.add_gridspec(3, 1, height_ratios=[0.1, 0.1, 1], hspace=0.2)

            fs_data = np.zeros((1, seq_length))
            prob_data = np.zeros((1, seq_length))

            # One pass fills both strips. Rows must pass both model thresholds;
            # the FS strip additionally requires an ensemble probability >= 0.8.
            rows = zip(results_df['Position'], results_df['Short_Probability'],
                       results_df['Long_Probability'], results_df['Ensemble_Probability'])
            for raw_pos, short_p, long_p, ensemble_p in rows:
                if not (short_p >= short_threshold and long_p >= long_threshold):
                    continue
                pos = int(raw_pos)

                start = max(0, pos - prob_half_width)
                end = min(seq_length, pos + prob_half_width + 1)
                prob_data[0, start:end] = ensemble_p

                if ensemble_p >= 0.8:
                    start_pos = max(0, pos - fs_half_width)
                    end_pos = min(seq_length, pos + fs_half_width + 1)
                    fs_data[0, start_pos:end_pos] = 1  # fixed value, no gradient

            # Strip 1: high-confidence predictions shown as potential FS sites.
            ax0 = fig.add_subplot(gs[0])
            ax0.imshow(fs_data, cmap='Reds', aspect='auto', interpolation='nearest')
            ax0.set_xticks([])
            ax0.set_yticks([])
            ax0.set_title('FS site', pad=2, fontsize=8)

            # Strip 2: ensemble probability of the rows passing both thresholds.
            ax1 = fig.add_subplot(gs[1])
            ax1.imshow(prob_data, cmap='Reds', aspect='auto', vmin=0, vmax=1,
                       interpolation='nearest')
            ax1.set_xticks([])
            ax1.set_yticks([])
            ax1.set_title('Prediction', pad=2, fontsize=8)

            # Main bar chart: rows failing either threshold are drawn as zero.
            ax2 = fig.add_subplot(gs[2])
            filtered_probs = results_df['Ensemble_Probability'].copy()
            mask = ((results_df['Short_Probability'] < short_threshold) |
                    (results_df['Long_Probability'] < long_threshold))
            filtered_probs[mask] = 0

            ax2.bar(results_df['Position'], filtered_probs, alpha=0.6, color='black', width=1.0)

            step = max(seq_length // 10, 50)
            ax2.set_xticks(np.arange(0, seq_length, step))
            ax2.tick_params(axis='x', rotation=45)
            ax2.set_xlabel('Position')
            ax2.set_ylabel('Probability')
            ax2.set_ylim(0, 1)
            ax2.grid(True, alpha=0.3)

            # Keep the three x-axes aligned.
            for ax in [ax0, ax1, ax2]:
                ax.set_xlim(-1, seq_length)

            fig.tight_layout()

            if save_path:
                save_path = os.fspath(save_path)  # accept os.PathLike as well as str
                fig.savefig(save_path, dpi=dpi, bbox_inches='tight')

                # Swap only the trailing extension; '.png' elsewhere in the
                # path (e.g. in a directory name) is left alone.
                if save_path.endswith('.png'):
                    pdf_path = save_path[:-len('.png')] + '.pdf'
                    fig.savefig(pdf_path, bbox_inches='tight')

                print(f"Plot saved to: {save_path}")

            return results_df, fig
        except Exception as e:
            if fig is not None:
                plt.close(fig)  # do not leak a half-built figure
            raise Exception(f"Error plotting sequence prediction: {str(e)}") from e

    @staticmethod
    def _as_sequence_list(sequences):
        """Normalise the ``predict_regions`` input to an iterable of sequences.

        A DataFrame contributes its ``Long_Sequence`` or ``399bp`` column when
        present, otherwise its first column; a Series becomes a list; a single
        string is wrapped in a list; anything else is returned unchanged.
        """
        if isinstance(sequences, pd.DataFrame):
            for column in ('Long_Sequence', '399bp'):
                if column in sequences.columns:
                    return sequences[column].tolist()
            if sequences.shape[1] == 0:
                return []
            return sequences.iloc[:, 0].tolist()
        if isinstance(sequences, pd.Series):
            return sequences.tolist()
        if isinstance(sequences, str):
            return [sequences]
        return sequences

    def predict_regions(self, sequences, short_threshold=0.1, ensemble_weight=0.4):
        """Batch-predict known 399 bp region sequences.

        Args:
            sequences: A 399 bp sequence, or a DataFrame / Series / list of
                them. For a DataFrame the ``Long_Sequence`` or ``399bp``
                column is used when present, otherwise the first column.
            short_threshold: Short model probability threshold (default 0.1).
            ensemble_weight: Weight of the short model in the ensemble
                (default 0.4).

        Returns:
            DataFrame: One row per input sequence with the prediction
            probabilities, ``Short_Sequence`` and ``Long_Sequence``. A
            sequence whose prediction fails is reported and gets a row of
            zero probabilities.

        Raises:
            Exception: If the weight is out of range or the batch cannot be
                processed.
        """
        try:
            if not (0.0 <= ensemble_weight <= 1.0):
                raise ValueError("ensemble_weight must be between 0.0 and 1.0")

            sequences = self._as_sequence_list(sequences)

            results = []
            for i, seq399 in enumerate(sequences):
                try:
                    # Central 33 bp of the region, for the short model.
                    seq33 = self._extract_center_sequence(seq399, target_length=self.short_seq_length)

                    pred_result = self.predict_single_position(seq33, seq399, short_threshold, ensemble_weight)
                    pred_result.update({
                        'Short_Sequence': seq33,
                        'Long_Sequence': seq399
                    })
                    results.append(pred_result)
                except Exception as e:
                    print(f"Error processing sequence {i+1}: {str(e)}")
                    long_weight = 1.0 - ensemble_weight
                    try:
                        if len(seq399) >= self.short_seq_length:
                            fallback_short = self._extract_center_sequence(
                                seq399, target_length=self.short_seq_length)
                        else:
                            fallback_short = seq399
                    except TypeError:
                        # Unsized input (e.g. None): keep the raw value so one
                        # bad item does not abort the whole batch.
                        fallback_short = seq399
                    results.append({
                        'Short_Probability': 0.0,
                        'Long_Probability': 0.0,
                        'Ensemble_Probability': 0.0,
                        'Ensemble_Weights': f'Short:{ensemble_weight:.1f}, Long:{long_weight:.1f}',
                        'Short_Sequence': fallback_short,
                        'Long_Sequence': seq399
                    })

            return pd.DataFrame(results)
        except Exception as e:
            raise Exception(f"Error in region prediction process: {str(e)}") from e

    def _extract_center_sequence(self, sequence, target_length=33):
        """Return the ``target_length`` characters centred in ``sequence``.

        The input is converted to an upper-case string. A sequence that is
        not longer than ``target_length`` is returned whole.
        """
        sequence = str(sequence).upper()

        if len(sequence) <= target_length:
            return sequence

        # len(sequence) > target_length guarantees 0 <= start and
        # start + target_length <= len(sequence), so no clamping is needed.
        start = len(sequence) // 2 - target_length // 2
        return sequence[start:start + target_length]

    @staticmethod
    def _add_legacy_columns(results_df):
        """Add the column aliases used by the deprecated API, in place."""
        legacy_aliases = (
            ('Ensemble_Probability', ('Voting_Probability', 'Weighted_Probability')),
            ('Ensemble_Weights', ('Weight_Info',)),
            ('Short_Sequence', ('33bp',)),
            ('Long_Sequence', ('399bp',)),
        )
        for current_name, old_names in legacy_aliases:
            if current_name in results_df.columns:
                for old_name in old_names:
                    results_df[old_name] = results_df[current_name]
        return results_df

    # Backward-compatibility methods (deprecated).
    def predict_full(self, sequence, window_size=3, short_threshold=0.1, short_weight=0.4, plot=False):
        """Deprecated: use ``predict_sequence()`` instead.

        Calls ``predict_sequence()`` and adds the legacy column aliases.
        With ``plot=True`` it also calls ``plot_sequence_prediction()`` with
        thresholds 0.65 / 0.8 and returns ``(results_df, fig)``; otherwise it
        returns the DataFrame alone.
        """
        import warnings
        warnings.warn("predict_full() 已废弃,请使用 predict_sequence() 方法",
                      DeprecationWarning, stacklevel=2)

        results_df = self.predict_sequence(sequence, window_size, short_threshold, short_weight)
        results_df = self._add_legacy_columns(results_df)

        if plot:
            _, fig = self.plot_sequence_prediction(sequence, window_size, 0.65, 0.8, short_weight)
            return results_df, fig

        return results_df

    def predict_region(self, seq, short_threshold=0.1, short_weight=0.4):
        """Deprecated: use ``predict_regions()`` instead.

        Calls ``predict_regions()`` and adds the legacy column aliases.
        """
        import warnings
        warnings.warn("predict_region() 已废弃,请使用 predict_regions() 方法",
                      DeprecationWarning, stacklevel=2)

        results_df = self.predict_regions(seq, short_threshold, short_weight)
        return self._add_legacy_columns(results_df)