80 lines
2.9 KiB
Python
80 lines
2.9 KiB
Python
|
|
import numpy as np
|
|||
|
|
from tensorflow.keras.preprocessing.sequence import pad_sequences
|
|||
|
|
from typing import List, Union
|
|||
|
|
|
|||
|
|
class CNNInputProcessor:
    """Input data processor for the CNN model.

    Turns a nucleotide sequence into an integer-encoded array of shape
    ``(1, max_length, 1)``: the sequence is upper-cased, ``U`` is replaced
    by ``T``, over-long sequences are cropped around their centre, and
    short sequences are right-padded.
    """

    # Code used both for unknown bases and for right-padding; it is the same
    # value the mapping assigns to 'N'.
    _UNKNOWN_CODE = 4

    def __init__(self, max_length: int = 399):
        """Create a processor.

        Args:
            max_length: Fixed number of positions every encoded sequence
                is cropped or padded to.
        """
        self.max_length = max_length
        # Per the original source comment, this encoding must stay identical
        # to the one used at training time (SemiBilstmCnn.py).
        self.base_to_num = {'A': 0, 'T': 1, 'C': 2, 'G': 3, 'N': 4}

    def trim_sequence(self, seq, target_length):
        """Crop ``seq`` around its centre to exactly ``target_length`` items.

        Args:
            seq: Original sequence (any sliceable object with ``len``).
            target_length: Desired length.

        Returns:
            ``seq`` unchanged when it is already no longer than
            ``target_length``; otherwise the centred window of length
            ``target_length``. When the number of surplus items is odd, the
            extra item is removed from the right-hand end.
        """
        if len(seq) <= target_length:
            return seq

        # Slice a fixed-width window rather than trimming ``excess // 2``
        # from both ends: the latter leaves one item too many whenever the
        # excess is odd.
        start = (len(seq) - target_length) // 2
        return seq[start:start + target_length]

    def prepare_sequence(self, sequence: str) -> np.ndarray:
        """Encode a single sequence for the CNN.

        Args:
            sequence: DNA/RNA sequence. Non-string values are converted
                with ``str()`` first.

        Returns:
            np.ndarray: Array of shape ``(1, max_length, 1)`` holding the
            integer code of each base. Bases missing from ``base_to_num``
            and padding positions are encoded as 4. If processing fails,
            a message is printed and an all-zero array of the same shape
            is returned instead.
        """
        try:
            if not isinstance(sequence, str):
                sequence = str(sequence)

            # Normalise case and treat RNA uracil as thymine.
            sequence = sequence.upper().replace('U', 'T')

            # No-op for sequences that already fit.
            sequence = self.trim_sequence(sequence, self.max_length)

            unknown = self._UNKNOWN_CODE
            encoded = [self.base_to_num.get(base, unknown) for base in sequence]

            # Right-pad up to max_length (a non-positive count adds nothing).
            encoded.extend([unknown] * (self.max_length - len(encoded)))

            # Shape is (samples, timesteps, features).
            return np.array(encoded).reshape(1, self.max_length, 1)

        except Exception as e:
            # Deliberate best-effort: report and fall back to an all-zero
            # array so callers always receive the expected shape.
            print(f"CNN序列处理失败: {str(e)}")
            return np.zeros((1, self.max_length, 1))