FScanpy-package/FScanpy/features/cnn_input.py

80 lines
2.9 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import numpy as np
from tensorflow.keras.preprocessing.sequence import pad_sequences
from typing import List, Union
class CNNInputProcessor:
    """Turn nucleotide sequences into fixed-length integer arrays for a CNN.

    A sequence is upper-cased, ``U`` is rewritten to ``T``, anything longer
    than ``max_length`` is centre-cropped, every base is mapped to an integer
    code, and anything shorter than ``max_length`` is right-padded with the
    same code that represents ``N`` / unknown bases.
    """

    # Integer code shared by 'N', unrecognised characters and right-padding.
    _UNKNOWN_CODE = 4

    def __init__(self, max_length: int = 399):
        """Create a processor.

        Args:
            max_length: Number of positions in every encoded sequence.
        """
        self.max_length = max_length
        # The original source notes that this encoding is kept consistent
        # with the one used in SemiBilstmCnn.py at training time.
        self.base_to_num = {'A': 0, 'T': 1, 'C': 2, 'G': 3, 'N': 4}

    def trim_sequence(self, seq, target_length):
        """Centre-crop ``seq`` so that it has exactly ``target_length`` items.

        Items are removed from both ends so the middle of the sequence is
        kept. When the number of items to remove is odd, the extra item is
        removed from the right-hand end.

        Args:
            seq: Sliceable sequence (typically a ``str`` of bases).
            target_length: Desired length of the result.

        Returns:
            ``seq`` unchanged if it is already no longer than
            ``target_length``; otherwise a slice of length ``target_length``.
        """
        if len(seq) <= target_length:
            return seq

        excess = len(seq) - target_length
        start = excess // 2
        # Slice by an explicit width rather than trimming ``excess // 2`` from
        # each side: the latter leaves one item too many when ``excess`` is odd.
        return seq[start:start + target_length]

    def prepare_sequence(self, sequence: str) -> np.ndarray:
        """Encode one sequence as an array of shape ``(1, max_length, 1)``.

        Args:
            sequence: DNA/RNA sequence. Non-string input is converted with
                ``str()`` before processing.

        Returns:
            np.ndarray: Integer codes laid out as (samples, timesteps,
            features). If processing fails for any reason, a message is
            printed and an all-zero float array of the same shape is returned
            instead of raising.
        """
        try:
            if not isinstance(sequence, str):
                sequence = str(sequence)
            sequence = sequence.upper().replace('U', 'T')

            # Overlong input is centre-cropped down to the model's input width.
            if len(sequence) > self.max_length:
                sequence = self.trim_sequence(sequence, self.max_length)

            # Characters missing from the table share the unknown-base code.
            encode = self.base_to_num.get
            seq_numeric = [encode(base, self._UNKNOWN_CODE) for base in sequence]

            # Right-pad short input up to the fixed width.
            missing = self.max_length - len(seq_numeric)
            if missing > 0:
                seq_numeric.extend([self._UNKNOWN_CODE] * missing)

            return np.array(seq_numeric).reshape(1, self.max_length, 1)
        except Exception as e:
            # Deliberate best-effort: report the failure and hand back a
            # correctly shaped placeholder so the caller's pipeline continues.
            # NOTE(review): an all-zero array is indistinguishable from an
            # all-'A' sequence under this encoding.
            print(f"CNN序列处理失败: {str(e)}")
            return np.zeros((1, self.max_length, 1))