FScanpy-package/FScanpy/features/cnn_input.py

80 lines
2.9 KiB
Python
Raw Permalink Normal View History

2025-03-18 11:21:54 +08:00
import numpy as np
from tensorflow.keras.preprocessing.sequence import pad_sequences
from typing import List, Union
class CNNInputProcessor:
    """Turn nucleotide sequences into fixed-length integer arrays for a CNN.

    Each sequence is upper-cased, has ``U`` mapped to ``T``, is centre-trimmed
    or right-padded to ``max_length`` and encoded with ``base_to_num``. The
    result has shape ``(1, max_length, 1)``.
    """

    def __init__(self, max_length: int = 399):
        """Create a processor.

        Args:
            max_length: Number of positions in every encoded sequence.
        """
        self.max_length = max_length
        # Integer code per base. Code 4 ('N') is also used for characters that
        # are not in the table and for padding positions.
        # NOTE(review): the original comment states this must match the
        # encoding used in SemiBilstmCnn.py at training time - keep in sync.
        self.base_to_num = {'A': 0, 'T': 1, 'C': 2, 'G': 3, 'N': 4}

    def trim_sequence(self, seq, target_length):
        """Cut ``seq`` down to ``target_length`` while keeping its centre.

        Characters are removed from both ends. When the number of surplus
        characters is odd, the right end loses one more character than the
        left so the result is always exactly ``target_length`` long.

        Args:
            seq: Sequence to trim (anything supporting ``len`` and slicing).
            target_length: Desired length.

        Returns:
            ``seq`` unchanged if it is already short enough, otherwise the
            central window of ``target_length`` items.
        """
        if len(seq) <= target_length:
            return seq
        # Slice a window of exactly target_length. Trimming excess // 2 from
        # each side would leave one item too many whenever the excess is odd.
        start = (len(seq) - target_length) // 2
        return seq[start:start + target_length]

    def prepare_sequence(self, sequence: str) -> np.ndarray:
        """Encode one sequence as an array of shape ``(1, max_length, 1)``.

        Args:
            sequence: DNA/RNA sequence. Non-string values are converted with
                ``str()`` first.

        Returns:
            Integer-coded array of shape ``(1, max_length, 1)``. If processing
            fails for any reason, a message is printed and an all-zero float
            array of the same shape is returned instead (best-effort fallback).
        """
        try:
            if not isinstance(sequence, str):
                sequence = str(sequence)
            # Normalise case and treat RNA uracil as thymine.
            sequence = sequence.upper().replace('U', 'T')
            # No-op for sequences that already fit.
            sequence = self.trim_sequence(sequence, self.max_length)

            # Characters missing from the table are encoded as 4.
            encoded = [self.base_to_num.get(base, 4) for base in sequence]
            # Right-pad with 4 up to max_length (multiplying by a
            # non-positive count yields an empty list, so no guard is needed).
            encoded.extend([4] * (self.max_length - len(encoded)))

            # (samples, timesteps, features)
            return np.array(encoded).reshape(1, self.max_length, 1)
        except Exception as e:
            print(f"CNN序列处理失败: {str(e)}")
            # Fall back to an all-zero array so callers always get a valid shape.
            return np.zeros((1, self.max_length, 1))