import os
import pickle

import joblib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tensorflow.keras.models import load_model

from .features.cnn_input import CNNInputProcessor
from .features.sequence import SequenceFeatureExtractor
from .utils import extract_window_sequences


class PRFPredictor:
    """Score candidate frameshift (PRF) sites with a two-stage model pipeline.

    Stage 1: a gradient-boosting model scores a short window
    (``gb_seq_length`` bp, features from ``SequenceFeatureExtractor``).
    Stage 2: only when the GB probability reaches ``gb_threshold``, a second
    model loaded from ``BiLSTM-CNN_all.pkl`` scores the long window
    (``cnn_seq_length`` bp). That model is used through ``predict_proba`` if it
    has one (sklearn-style), otherwise through ``predict`` on the output of
    ``CNNInputProcessor.prepare_sequence``.

    The combined score ``Voting_Probability`` is
    ``GB_WEIGHT * gb + CNN_WEIGHT * cnn``; it is 0.0 when the GB gate is not passed.
    """

    # Weights of the combined score (this weighted average replaced an earlier voting model).
    GB_WEIGHT = 0.4
    CNN_WEIGHT = 0.6

    # Column order of the returned DataFrames. Passing it explicitly keeps the
    # schema stable even when no window/sequence produced a result.
    _FULL_COLUMNS = ['GB_Probability', 'CNN_Probability', 'Voting_Probability',
                     'Position', 'Codon', '33bp', '399bp']
    _REGION_COLUMNS = ['GB_Probability', 'CNN_Probability', 'Voting_Probability',
                       '33bp', '399bp']

    def __init__(self, model_dir=None):
        """Load both models and build the feature/input processors.

        Args:
            model_dir: Directory containing ``GradientBoosting_all.pkl`` and
                ``BiLSTM-CNN_all.pkl``. Defaults to the ``pretrained`` directory
                shipped inside the ``FScanpy`` package.

        Raises:
            FileNotFoundError: A model file is missing.
            Exception: Any other failure while loading models or building processors.
        """
        if model_dir is None:
            model_dir = self._default_model_dir()

        try:
            self.gb_model = self._load_pickle(os.path.join(model_dir, 'GradientBoosting_all.pkl'))
            self.cnn_model = self._load_pickle(os.path.join(model_dir, 'BiLSTM-CNN_all.pkl'))

            # Window lengths must match the ones used at training time.
            self.gb_seq_length = 33    # input length of the gradient-boosting model
            self.cnn_seq_length = 399  # input length of the BiLSTM-CNN model

            self.feature_extractor = SequenceFeatureExtractor(seq_length=self.gb_seq_length)
            self.cnn_processor = CNNInputProcessor(max_length=self.cnn_seq_length)
        except FileNotFoundError as e:
            raise FileNotFoundError(f"无法找到模型文件: {str(e)}") from e
        except Exception as e:
            raise Exception(f"加载模型出错: {str(e)}") from e

    @staticmethod
    def _default_model_dir():
        """Return the filesystem path of the packaged ``FScanpy/pretrained`` directory."""
        try:
            # Preferred on Python 3.9+: pkg_resources is deprecated by setuptools.
            from importlib.resources import files
            candidate = str(files('FScanpy') / 'pretrained')
            if os.path.isdir(candidate):
                return candidate
        except (ImportError, TypeError):
            pass
        # Fallback for old interpreters or non-directory installs (e.g. zipped),
        # where resource_filename extracts the resources to disk.
        from pkg_resources import resource_filename
        return resource_filename('FScanpy', 'pretrained')

    def _load_pickle(self, path):
        """Deserialize a model with ``joblib.load``.

        NOTE: unpickling can execute arbitrary code, so only load model files
        from a trusted ``model_dir``.
        """
        return joblib.load(path)

    # ------------------------------------------------------------------ #
    # Per-window scoring helpers
    # ------------------------------------------------------------------ #
    @staticmethod
    def _to_feature_matrix(features):
        """Wrap one sample's features as a ``(1, n_features)`` array, flattening ndarray input of rank > 1."""
        if isinstance(features, np.ndarray) and features.ndim > 1:
            features = features.flatten()
        return np.array([features])

    def _predict_gb(self, seq33, index=None):
        """Return the GB model's positive-class probability for *seq33*.

        Errors are reported with ``print`` and mapped to 0.0 so that one bad
        window does not abort a whole scan. *index* (1-based) labels messages
        when scoring a batch of sequences.
        """
        try:
            features = self.feature_extractor.extract_features(seq33)
            if isinstance(features, np.ndarray) and features.ndim > 1:
                owner = "" if index is None else f"序列 {index} 的"
                print(f"警告: {owner}特征是{features.ndim}维数组,进行扁平化处理")
            matrix = self._to_feature_matrix(features)
            if matrix.ndim != 2:
                raise ValueError(f"处理后特征仍为{matrix.ndim}维,需要二维数组")
            return float(self.gb_model.predict_proba(matrix)[0][1])
        except Exception as e:
            where = "" if index is None else f"序列 {index} "
            print(f"GB模型预测{where}时出错: {str(e)}")
            return 0.0

    def _call_predict(self, model_input):
        """Call ``self.cnn_model.predict`` while tolerating different predict signatures."""
        try:
            # Keras models accept verbose=0, which suppresses the per-call progress bar.
            return self.cnn_model.predict(model_input, verbose=0)
        except TypeError:
            pass  # predict() does not take `verbose`
        try:
            return self.cnn_model.predict(model_input)
        except Exception:
            # Last resort: models that expect a flat (1, n) input.
            return self.cnn_model.predict(np.asarray(model_input).reshape(1, -1))

    @staticmethod
    def _positive_probability(pred):
        """Extract the first sample's positive-class probability from a raw prediction.

        A ``(n, k>1)`` output is read as per-class probabilities (column 1);
        anything else is read as a single probability (the first element).
        """
        if isinstance(pred, list):  # multi-output models return a list of arrays
            pred = pred[0]
        arr = np.asarray(pred, dtype=float)
        if arr.ndim > 1 and arr.shape[1] > 1:
            return float(arr[0, 1])
        # ravel() also handles 1-D and 0-d outputs, where indexing twice would fail.
        return float(arr.ravel()[0])

    def _predict_cnn(self, seq399, index=None):
        """Return the second-stage model's positive-class probability for *seq399* (0.0 on error)."""
        try:
            if hasattr(self.cnn_model, 'predict_proba'):
                # sklearn-style model: uses the same feature extractor as the GB model.
                matrix = self._to_feature_matrix(self.feature_extractor.extract_features(seq399))
                return float(self.cnn_model.predict_proba(matrix)[0][1])
            model_input = self.cnn_processor.prepare_sequence(seq399)
            return self._positive_probability(self._call_predict(model_input))
        except Exception as e:
            where = "" if index is None else f"序列 {index} "
            print(f"CNN模型预测{where}时出错: {str(e)}")
            return 0.0

    def _score_window(self, seq33, seq399, gb_threshold, index=None):
        """Run the GB gate and, if passed, the second model; return the three probabilities."""
        gb_prob = self._predict_gb(seq33, index)
        if gb_prob < gb_threshold:
            return {
                'GB_Probability': gb_prob,
                'CNN_Probability': 0.0,
                'Voting_Probability': 0.0,
            }
        cnn_prob = self._predict_cnn(seq399, index)
        return {
            'GB_Probability': gb_prob,
            'CNN_Probability': cnn_prob,
            'Voting_Probability': self.GB_WEIGHT * gb_prob + self.CNN_WEIGHT * cnn_prob,
        }

    # ------------------------------------------------------------------ #
    # Public API
    # ------------------------------------------------------------------ #
    def predict_single_position(self, fs_period, full_seq, gb_threshold=0.1):
        """Score a single position.

        Args:
            fs_period: Short window for the GB model. If longer than
                ``gb_seq_length`` it is trimmed with ``feature_extractor.trim_sequence``.
            full_seq: Long window for the second-stage model.
            gb_threshold: The second model only runs when the GB probability is
                at least this value (default 0.1).

        Returns:
            dict with keys ``GB_Probability``, ``CNN_Probability`` and ``Voting_Probability``.

        Raises:
            Exception: Wraps any unexpected error (per-model errors are already
                mapped to 0.0).
        """
        try:
            if len(fs_period) > self.gb_seq_length:
                fs_period = self.feature_extractor.trim_sequence(fs_period, self.gb_seq_length)
            return self._score_window(fs_period, full_seq, gb_threshold)
        except Exception as e:
            raise Exception(f"预测过程出错: {str(e)}") from e

    def predict_full(self, sequence, window_size=3, gb_threshold=0.1, plot=False):
        """Scan a full sequence with a sliding window and score each position.

        Args:
            sequence: Input DNA sequence. It is converted with ``str()`` and upper-cased.
            window_size: Step between scored positions (default 3); must be >= 1.
            gb_threshold: GB probability gate (default 0.1); must be >= 0.
            plot: Also build a line plot and a heatmap of the scores (default False).

        Returns:
            ``pd.DataFrame`` with columns ``_FULL_COLUMNS`` when ``plot`` is False;
            otherwise the tuple ``(DataFrame, matplotlib.figure.Figure)``.

        Raises:
            ValueError: Invalid ``window_size`` or ``gb_threshold``.
            Exception: Wraps any error raised during scanning or plotting.
        """
        if window_size < 1:
            raise ValueError("窗口大小必须大于等于1")
        if gb_threshold < 0:
            raise ValueError("GB阈值必须大于等于0")

        try:
            sequence = str(sequence).upper()
            results = []
            for pos in range(0, len(sequence) - 2, window_size):
                # Windows are cut the same way as during training; positions for
                # which no window can be built come back as None and are skipped.
                fs_period, full_seq = extract_window_sequences(sequence, pos)
                if fs_period is None or full_seq is None:
                    continue
                pred = self.predict_single_position(fs_period, full_seq, gb_threshold)
                pred.update({
                    'Position': pos,
                    'Codon': sequence[pos:pos + 3],
                    '33bp': fs_period,
                    '399bp': full_seq,
                })
                results.append(pred)

            results_df = pd.DataFrame(results, columns=self._FULL_COLUMNS)
            if plot:
                return results_df, self._plot_predictions(results_df)
            return results_df
        except Exception as e:
            raise Exception(f"序列预测过程出错: {str(e)}") from e

    def _plot_predictions(self, results_df):
        """Draw per-position probabilities (top) and a one-row heatmap of ``Voting_Probability`` (bottom)."""
        # gridspec_kw works on all matplotlib versions; the height_ratios= shortcut needs >= 3.6.
        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 10),
                                       gridspec_kw={'height_ratios': [2, 1]})

        positions = results_df['Position'].to_numpy()
        ax1.plot(positions, results_df['GB_Probability'], label='GB模型', alpha=0.7, linewidth=1.5)
        ax1.plot(positions, results_df['CNN_Probability'], label='CNN模型', alpha=0.7, linewidth=1.5)
        ax1.plot(positions, results_df['Voting_Probability'], label='投票模型', linewidth=2, color='red')
        ax1.set_xlabel('序列位置')
        ax1.set_ylabel('移码概率')
        ax1.set_title('移码预测概率')
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        if len(positions):
            heatmap = results_df['Voting_Probability'].to_numpy(dtype=float).reshape(1, -1)
            left, right = positions.min(), positions.max()
            if right == left:
                right = left + 1  # avoid a zero-width extent when only one window was scored
            im = ax2.imshow(heatmap, aspect='auto', cmap='YlOrRd', extent=[left, right, 0, 1])
            cbar = fig.colorbar(im, ax=ax2)
            cbar.set_label('移码概率')

        ax2.set_xlabel('序列位置')
        ax2.set_title('移码概率热图')
        ax2.set_yticks([])

        fig.tight_layout()
        return fig

    def predict_region(self, seq, gb_threshold=0.1):
        """Score one or more long (``cnn_seq_length`` bp) region sequences.

        The GB window is the centre ``gb_seq_length`` bp of each region
        (see ``_extract_center_sequence``).

        Args:
            seq: A single sequence string, an iterable of sequences, a
                ``pd.Series``, or a ``pd.DataFrame`` with a ``'399bp'`` column
                (or exactly one column).
            gb_threshold: GB probability gate (default 0.1).

        Returns:
            ``pd.DataFrame`` with columns ``_REGION_COLUMNS``, one row per input sequence.

        Raises:
            Exception: Wraps input-shape errors (e.g. a multi-column DataFrame
                without ``'399bp'``). Per-sequence errors yield a zero-score row.
        """
        try:
            sequences = self._as_sequence_list(seq)
            results = [self._predict_region_row(seq399, index, gb_threshold)
                       for index, seq399 in enumerate(sequences, start=1)]
            return pd.DataFrame(results, columns=self._REGION_COLUMNS)
        except Exception as e:
            raise Exception(f"区域预测过程出错: {str(e)}") from e

    @staticmethod
    def _as_sequence_list(seq):
        """Normalize the accepted ``predict_region`` inputs into a list of sequences."""
        if isinstance(seq, pd.DataFrame):
            # DataFrame has no .tolist(); pick the sequence column explicitly.
            if '399bp' in seq.columns:
                seq = seq['399bp']
            elif seq.shape[1] == 1:
                seq = seq.iloc[:, 0]
            else:
                raise ValueError("DataFrame输入必须包含'399bp'列或仅包含一列序列")
        if isinstance(seq, pd.Series):
            return seq.tolist()
        if isinstance(seq, str):
            return [seq]
        return list(seq)

    def _predict_region_row(self, seq399, index, gb_threshold):
        """Score one region sequence; any unexpected error yields a zero-score row."""
        # Upper-case once so that both models see the same normalized text
        # (predict_full and _extract_center_sequence also upper-case).
        seq399 = str(seq399).upper()
        seq33 = self._extract_center_sequence(seq399, target_length=self.gb_seq_length)
        try:
            row = self._score_window(seq33, seq399, gb_threshold, index=index)
        except Exception as e:
            print(f"处理第 {index} 个序列时出错: {str(e)}")
            row = {'GB_Probability': 0.0, 'CNN_Probability': 0.0, 'Voting_Probability': 0.0}
        row.update({'33bp': seq33, '399bp': seq399})
        return row

    def _extract_center_sequence(self, sequence, target_length=33):
        """Return the centred ``target_length``-long slice of ``str(sequence).upper()``.

        Sequences no longer than ``target_length`` are returned whole (upper-cased).
        """
        sequence = str(sequence).upper()
        if len(sequence) <= target_length:
            return sequence
        start = len(sequence) // 2 - target_length // 2
        # Clamp defensively so the slice always stays inside the sequence.
        start = max(0, min(start, len(sequence) - target_length))
        return sequence[start:start + target_length]