80 lines
2.9 KiB
Python
80 lines
2.9 KiB
Python
import numpy as np
|
||
from tensorflow.keras.preprocessing.sequence import pad_sequences
|
||
from typing import List, Union
|
||
|
||
class CNNInputProcessor:
    """Input preprocessor for the CNN model.

    Converts a nucleotide sequence string into a fixed-length integer-coded
    array of shape ``(1, max_length, 1)``.
    """

    def __init__(self, max_length: int = 399):
        """Store the target length and the base-to-integer encoding table.

        Args:
            max_length: Fixed number of positions every encoded sequence is
                trimmed or padded to.
        """
        self.max_length = max_length
        # Original note: this encoding is kept identical to the one used in
        # SemiBilstmCnn.py (the training-time encoding).
        self.base_to_num = {'A': 0, 'T': 1, 'C': 2, 'G': 3, 'N': 4}

    def trim_sequence(self, seq, target_length: int):
        """Cut ``seq`` down to ``target_length`` items, keeping its centre.

        Items are removed from both ends as evenly as possible. When the
        number of surplus items is odd, the extra item is removed from the
        right end, so the result is always exactly ``target_length`` long.

        Args:
            seq: Original sequence (any sliceable object with ``len``).
            target_length: Desired length of the result.

        Returns:
            ``seq`` unchanged if it is already no longer than
            ``target_length``; otherwise the centred slice of length
            ``target_length``.
        """
        if len(seq) <= target_length:
            return seq

        excess = len(seq) - target_length
        start = excess // 2
        # Slice by explicit length rather than trimming `start` items from
        # each end: trimming symmetrically leaves one item too many whenever
        # `excess` is odd.
        return seq[start:start + target_length]

    def prepare_sequence(self, sequence: str) -> np.ndarray:
        """Encode a single sequence as a ``(1, max_length, 1)`` array.

        The input is upper-cased, ``U`` is mapped to ``T``, over-long input
        is centre-trimmed, each base is looked up in ``self.base_to_num``
        (unknown characters become 4), and short input is right-padded
        with 4.

        Args:
            sequence: DNA/RNA sequence. Non-string values are converted
                with ``str()``.

        Returns:
            Integer-coded array shaped (samples, timesteps, features) =
            ``(1, max_length, 1)``. If processing fails, an all-zero array
            of the same shape is returned and the error is printed.
        """
        unknown_code = 4  # code used for unknown bases and for padding

        try:
            if not isinstance(sequence, str):
                sequence = str(sequence)

            sequence = sequence.upper().replace('U', 'T')

            if len(sequence) > self.max_length:
                sequence = self.trim_sequence(sequence, self.max_length)

            encoded = [
                self.base_to_num.get(base, unknown_code) for base in sequence
            ]

            # Right-pad short sequences up to the fixed length.
            shortfall = self.max_length - len(encoded)
            if shortfall > 0:
                encoded.extend([unknown_code] * shortfall)

            # Reshape to (samples, timesteps, features).
            return np.array(encoded).reshape(1, self.max_length, 1)

        except Exception as e:
            # Deliberate best-effort: report the problem and hand back an
            # all-zero array of the expected shape instead of raising.
            print(f"CNN序列处理失败: {str(e)}")
            return np.zeros((1, self.max_length, 1))