FScanpy-commit-code/train_models/bilstm_cnn.py

"""
BiLSTM-CNN Model for Sequence Classification
"""
import os
import numpy as np
import pandas as pd
import tensorflow as tf
from sklearn.metrics import roc_auc_score, accuracy_score, f1_score, precision_score, recall_score, log_loss
from tensorflow.keras import layers, models
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.preprocessing.sequence import pad_sequences
from utils.function import load_data, save_training_info, select_low_confidence_samples_cnn, evaluate_model_cnn
from utils.config import BaseConfig
# Set random seeds
np.random.seed(42)
tf.random.set_seed(42)
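# Note: seeding NumPy and TensorFlow improves reproducibility, but GPU kernels
# may still be nondeterministic unless op determinism is enabled
# (tf.config.experimental.enable_op_determinism() in TF >= 2.8).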


class MetricsCallback(tf.keras.callbacks.Callback):
    """Callback for recording training metrics"""

    def __init__(self):
        super().__init__()
        self.training_metrics = {
            'train_loss': [], 'train_auc': [], 'train_accuracy': [],
            'train_recall': [], 'train_precision': [], 'train_f1': [],
            'test_loss': [], 'test_auc': [], 'test_accuracy': [],
            'test_recall': [], 'test_precision': [], 'test_f1': []
        }
        self.iteration_metrics = {
            'samples_added': [0],
            'total_samples': []
        }
        self.best_model = None
        self.best_test_loss = float('inf')
        self.best_epoch = -1
        self.best_predictions = None
        self.xu_metrics_history = {
            'loss': [], 'auc': [], 'accuracy': [],
            'recall': [], 'precision': [], 'f1': []
        }
        self.atkins_metrics_history = {
            'loss': [], 'auc': [], 'accuracy': [],
            'recall': [], 'precision': [], 'f1': []
        }
        self.self_training_best_model = None
        self.self_training_best_loss = float('inf')
        self.self_training_best_metrics = None

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        try:
            train_metrics = {
                'loss': logs.get('loss', 0.0),
                'auc': logs.get('auc', 0.0),
                'accuracy': logs.get('accuracy', 0.0),
                'recall': logs.get('recall', 0.0),
                # Keras does not log precision/F1 by default; left at 0.0 here
                'precision': 0.0,
                'f1': 0.0
            }
            # Calculate test metrics using batch processing
            batch_size = 128
            n_test_samples = len(self.model.X_test)
            n_test_batches = (n_test_samples + batch_size - 1) // batch_size
            test_probs = np.zeros(n_test_samples)
            try:
                for i in range(n_test_batches):
                    start_idx = i * batch_size
                    end_idx = min((i + 1) * batch_size, n_test_samples)
                    batch_probs = self.model.predict(self.model.X_test[start_idx:end_idx], verbose=0)
                    if isinstance(batch_probs, list):
                        batch_probs = batch_probs[0]
                    if len(batch_probs.shape) > 1:
                        batch_probs = batch_probs.flatten()
                    test_probs[start_idx:end_idx] = batch_probs
                test_preds = (test_probs > 0.5).astype(int)
                test_metrics = {
                    'loss': log_loss(self.model.y_test, np.clip(test_probs, 1e-15, 1 - 1e-15)),
                    # ROC AUC is undefined when y_test contains a single class
                    'auc': roc_auc_score(self.model.y_test, test_probs) if len(np.unique(self.model.y_test)) > 1 else 0.5,
                    'accuracy': accuracy_score(self.model.y_test, test_preds),
                    'recall': recall_score(self.model.y_test, test_preds, zero_division=0),
                    'precision': precision_score(self.model.y_test, test_preds, zero_division=0),
                    'f1': f1_score(self.model.y_test, test_preds, zero_division=0)
                }
            except Exception:
                test_metrics = {
                    'loss': float('inf'), 'auc': 0.0, 'accuracy': 0.0,
                    'recall': 0.0, 'precision': 0.0, 'f1': 0.0
                }
            # Record metrics
            for key in self.training_metrics:
                if key.startswith('train_'):
                    metric_name = key[len('train_'):]
                    self.training_metrics[key].append(train_metrics.get(metric_name, 0.0))
                elif key.startswith('test_'):
                    metric_name = key[len('test_'):]
                    self.training_metrics[key].append(test_metrics.get(metric_name, 0.0))
            # Update best model based on test loss
            if test_metrics['loss'] < self.best_test_loss:
                self.best_test_loss = test_metrics['loss']
                self.best_epoch = epoch
                self.best_model = tf.keras.models.clone_model(self.model)
                self.best_model.set_weights(self.model.get_weights())
                self.best_predictions = test_probs.copy()
            # Evaluate external validation sets if available
            if hasattr(self.model, 'X_xu') and self.model.X_xu is not None:
                xu_metrics = evaluate_model_cnn(self.model, self.model.X_xu, self.model.y_xu)
                for key in self.xu_metrics_history:
                    self.xu_metrics_history[key].append(xu_metrics.get(key, 0.0))
            if hasattr(self.model, 'X_atkins') and self.model.X_atkins is not None:
                atkins_metrics = evaluate_model_cnn(self.model, self.model.X_atkins, self.model.y_atkins)
                for key in self.atkins_metrics_history:
                    self.atkins_metrics_history[key].append(atkins_metrics.get(key, 0.0))
        except Exception:
            # Swallow metric-computation errors so training itself is never interrupted
            pass

    def on_train_end(self, logs=None):
        # Restore the weights from the best epoch (lowest test loss)
        if self.best_model is not None:
            self.model.set_weights(self.best_model.get_weights())


class Config:
    """Model configuration parameters"""
    NEG_SAMPLES = 20000
    CONFIDENCE_THRESHOLD = 0.5
    EMBEDDING_DIM = 64
    LSTM_UNITS = 64
    CNN_FILTERS = 64
    CNN_KERNEL_SIZES = [3, 5, 7]
    DROPOUT_RATE = 0.5
    LEARNING_RATE = 1e-4
    BATCH_SIZE = 1024
    EPOCHS = 5
    INITIAL_EPOCHS = 5
    SELF_TRAINING_EPOCHS = 1
    MAX_ITERATIONS = 20
    EARLY_STOPPING_PATIENCE = 5
    Sequence_len = 399
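
# Note: NEG_SAMPLES, CONFIDENCE_THRESHOLD and EPOCHS are not referenced in this
# file; the initial fit runs for INITIAL_EPOCHS and each self-training round
# for SELF_TRAINING_EPOCHS.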


def process_sequence(seq, max_length=399):
    """Truncate a sequence to at most max_length bases"""
    return seq[:max_length] if len(seq) > max_length else seq


def encode_sequence(seq, max_length=399):
    """Integer-encode a sequence (A/T/C/G -> 1..4, other -> 0) and zero-pad to max_length"""
    mapping = {'A': 1, 'T': 2, 'C': 3, 'G': 4}
    encoded = [mapping.get(base, 0) for base in seq.upper()]
    if len(encoded) < max_length:
        encoded.extend([0] * (max_length - len(encoded)))
    return encoded[:max_length]
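
# Example (from the mapping above): encode_sequence("ATG", max_length=6)
# returns [1, 2, 4, 0, 0, 0] -- padding uses the same 0 index as unknown bases.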


def trim_sequence(seq, target_length):
    """Trim a sequence symmetrically from both ends to target_length (currently unused in this file)"""
    if len(seq) <= target_length:
        return seq
    excess = len(seq) - target_length
    left_trim = excess // 2
    right_trim = excess - left_trim
    return seq[left_trim:len(seq) - right_trim]
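
# Example: trim_sequence("ATCGATC", 5) -> "TCGAT"
# (an excess of 2 is split as 1 base off the left and 1 off the right).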


def prepare_data(train_data, test_data=None, low_conf_data=None, max_length=399):
    """Prepare training and test data"""
    # Process training data
    train_sequences = []
    train_labels = []
    sample_weights = []
    for _, row in train_data.iterrows():
        seq = process_sequence(row['full_seq'], max_length)
        encoded_seq = encode_sequence(seq, max_length)
        train_sequences.append(encoded_seq)
        train_labels.append(row['label'])
        weight = 1.0
        if 'sample_weight' in row and pd.notna(row['sample_weight']):
            weight = row['sample_weight']
        sample_weights.append(weight)
    X_train = np.array(train_sequences)
    y_train = np.array(train_labels)
    sample_weights = np.array(sample_weights)
    # Process test data
    X_test = y_test = None
    if test_data is not None and not test_data.empty:
        test_sequences = []
        test_labels = []
        for _, row in test_data.iterrows():
            seq = process_sequence(row['full_seq'], max_length)
            encoded_seq = encode_sequence(seq, max_length)
            test_sequences.append(encoded_seq)
            test_labels.append(row['label'])
        X_test = np.array(test_sequences)
        y_test = np.array(test_labels)
    # Process low confidence data
    X_low_conf = y_low_conf = None
    if low_conf_data is not None and not low_conf_data.empty:
        low_conf_sequences = []
        low_conf_labels = []
        for _, row in low_conf_data.iterrows():
            seq = process_sequence(row['full_seq'], max_length)
            encoded_seq = encode_sequence(seq, max_length)
            low_conf_sequences.append(encoded_seq)
            low_conf_labels.append(row['label'])
        X_low_conf = np.array(low_conf_sequences)
        y_low_conf = np.array(low_conf_labels)
    return X_train, y_train, X_test, y_test, sample_weights, X_low_conf, y_low_conf
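
# With Config defaults, prepare_data yields X arrays of shape (n_samples, 399)
# holding integer codes in {0..4}; the input DataFrames are expected to carry
# 'full_seq' and 'label' columns, and optionally 'sample_weight'.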


def create_bilstm_cnn_model(input_shape):
    """Create BiLSTM-CNN model"""
    input_layer = layers.Input(shape=input_shape)
    # Embedding layer (vocabulary: 0 = pad/unknown, 1-4 = A/T/C/G)
    embedding = layers.Embedding(
        input_dim=5,
        output_dim=Config.EMBEDDING_DIM,
        input_length=input_shape[0]
    )(input_layer)
    # BiLSTM layer
    lstm_out = layers.Bidirectional(
        layers.LSTM(Config.LSTM_UNITS, return_sequences=True, dropout=Config.DROPOUT_RATE)
    )(embedding)
    # Parallel CNN branches with different kernel sizes
    cnn_outputs = []
    for kernel_size in Config.CNN_KERNEL_SIZES:
        cnn = layers.Conv1D(
            filters=Config.CNN_FILTERS,
            kernel_size=kernel_size,
            activation='relu',
            padding='same'
        )(lstm_out)
        cnn = layers.GlobalMaxPooling1D()(cnn)
        cnn_outputs.append(cnn)
    # Concatenate CNN outputs
    if len(cnn_outputs) > 1:
        concat = layers.Concatenate()(cnn_outputs)
    else:
        concat = cnn_outputs[0]
    # Dense layers
    dense = layers.Dense(128, activation='relu')(concat)
    dense = layers.Dropout(Config.DROPOUT_RATE)(dense)
    dense = layers.Dense(64, activation='relu')(dense)
    dense = layers.Dropout(Config.DROPOUT_RATE)(dense)
    # Output layer
    output = layers.Dense(1, activation='sigmoid')(dense)
    model = models.Model(inputs=input_layer, outputs=output)
    # Compile model; explicit metric names guarantee the 'auc' and 'recall'
    # log keys that MetricsCallback reads from `logs`
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=Config.LEARNING_RATE),
        loss='binary_crossentropy',
        metrics=[
            'accuracy',
            tf.keras.metrics.AUC(name='auc'),
            tf.keras.metrics.Recall(name='recall')
        ]
    )
    return model
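
# Shape flow with Config defaults (batch dimension omitted):
#   input (399,) ints -> Embedding -> (399, 64)
#   -> Bidirectional LSTM(64) -> (399, 128)
#   -> three Conv1D(64, k in {3,5,7}, padding='same') branches -> (399, 64) each
#   -> GlobalMaxPooling1D -> (64,) each -> Concatenate -> (192,)
#   -> Dense(128) -> Dense(64) -> Dense(1, sigmoid).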


def train_bilstm_cnn_model(X_train, y_train, X_test, y_test, sample_weights=None,
                           X_xu=None, y_xu=None, X_atkins=None, y_atkins=None):
    """Train BiLSTM-CNN model with self-training"""
    # Create model
    input_shape = (X_train.shape[1],)
    model = create_bilstm_cnn_model(input_shape)
    # Store validation data on the model object so the callback can access it
    model.X_test = X_test
    model.y_test = y_test
    model.X_xu = X_xu
    model.y_xu = y_xu
    model.X_atkins = X_atkins
    model.y_atkins = y_atkins
    # Callbacks for initial training
    metrics_callback = MetricsCallback()
    early_stopping = tf.keras.callbacks.EarlyStopping(
        monitor='val_loss',
        patience=Config.EARLY_STOPPING_PATIENCE,
        restore_best_weights=True,
        verbose=0
    )
    # Hold out part of the training data for validation
    val_split = 0.2
    n_val = int(len(X_train) * val_split)
    indices = np.random.permutation(len(X_train))
    train_indices = indices[n_val:]
    val_indices = indices[:n_val]
    X_train_split = X_train[train_indices]
    y_train_split = y_train[train_indices]
    X_val_split = X_train[val_indices]
    y_val_split = y_train[val_indices]
    if sample_weights is not None:
        sample_weights_split = sample_weights[train_indices]
    else:
        sample_weights_split = None
    # Initial training
    model.fit(
        X_train_split, y_train_split,
        validation_data=(X_val_split, y_val_split),
        epochs=Config.INITIAL_EPOCHS,
        batch_size=Config.BATCH_SIZE,
        sample_weight=sample_weights_split,
        callbacks=[metrics_callback, early_stopping],
        verbose=0
    )
    # Store initial training info
    initial_info = {
        'best_test_loss': metrics_callback.best_test_loss,
        'best_epoch': metrics_callback.best_epoch,
        'training_metrics': metrics_callback.training_metrics.copy()
    }
    # Self-training iterations
    current_X_train = X_train.copy()
    current_y_train = y_train.copy()
    current_weights = sample_weights.copy() if sample_weights is not None else None
    iteration_metrics = {
        'iteration': [0],
        'train_loss': [metrics_callback.training_metrics['train_loss'][-1]],
        'test_loss': [metrics_callback.training_metrics['test_loss'][-1]],
        'samples_added': [0],
        'total_samples': [len(current_X_train)]
    }
    if X_xu is not None:
        xu_metrics = evaluate_model_cnn(model, X_xu, y_xu)
        iteration_metrics['xu_loss'] = [xu_metrics['loss']]
    if X_atkins is not None:
        atkins_metrics = evaluate_model_cnn(model, X_atkins, y_atkins)
        iteration_metrics['atkins_loss'] = [atkins_metrics['loss']]
    best_model = tf.keras.models.clone_model(model)
    best_model.set_weights(model.get_weights())
    best_loss = metrics_callback.best_test_loss
    best_iteration = 0
    # Load the low-confidence pool for self-training
    _, _, low_conf_data, _, _ = load_data()
    if low_conf_data is not None and not low_conf_data.empty:
        X_unlabeled, _, _, _, _, _, _ = prepare_data(
            low_conf_data, pd.DataFrame(), max_length=Config.Sequence_len
        )
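        # Each self-training round: (1) score the low-confidence pool,
        # (2) select samples via select_low_confidence_samples_cnn (which is
        # expected to return rows carrying 'full_seq' and 'label' columns),
        # (3) append them to the training set, (4) retrain briefly, and
        # (5) keep the weights with the lowest test loss seen so far.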
        for iteration in range(1, Config.MAX_ITERATIONS + 1):
            # Select low confidence samples
            selected_samples = select_low_confidence_samples_cnn(
                model, X_unlabeled, low_conf_data
            )
            if selected_samples.empty:
                break
            # Prepare selected samples
            X_selected, y_selected, _, _, weights_selected, _, _ = prepare_data(
                selected_samples, pd.DataFrame(), max_length=Config.Sequence_len
            )
            if len(X_selected) == 0:
                break
            # Add to training set
            current_X_train = np.vstack([current_X_train, X_selected])
            current_y_train = np.hstack([current_y_train, y_selected])
            if current_weights is not None:
                current_weights = np.hstack([current_weights, weights_selected])
            # Retrain model
            metrics_callback = MetricsCallback()
            # Split updated training data
            n_val = int(len(current_X_train) * val_split)
            indices = np.random.permutation(len(current_X_train))
            train_indices = indices[n_val:]
            val_indices = indices[:n_val]
            X_train_split = current_X_train[train_indices]
            y_train_split = current_y_train[train_indices]
            X_val_split = current_X_train[val_indices]
            y_val_split = current_y_train[val_indices]
            if current_weights is not None:
                sample_weights_split = current_weights[train_indices]
            else:
                sample_weights_split = None
            model.fit(
                X_train_split, y_train_split,
                validation_data=(X_val_split, y_val_split),
                epochs=Config.SELF_TRAINING_EPOCHS,
                batch_size=Config.BATCH_SIZE,
                sample_weight=sample_weights_split,
                callbacks=[metrics_callback, early_stopping],
                verbose=0
            )
            # Record iteration metrics
            iteration_metrics['iteration'].append(iteration)
            iteration_metrics['train_loss'].append(metrics_callback.training_metrics['train_loss'][-1])
            iteration_metrics['test_loss'].append(metrics_callback.training_metrics['test_loss'][-1])
            iteration_metrics['samples_added'].append(len(X_selected))
            iteration_metrics['total_samples'].append(len(current_X_train))
            if X_xu is not None:
                xu_metrics = evaluate_model_cnn(model, X_xu, y_xu)
                iteration_metrics['xu_loss'].append(xu_metrics['loss'])
            if X_atkins is not None:
                atkins_metrics = evaluate_model_cnn(model, X_atkins, y_atkins)
                iteration_metrics['atkins_loss'].append(atkins_metrics['loss'])
            # Update best model
            current_loss = metrics_callback.training_metrics['test_loss'][-1]
            if current_loss < best_loss:
                best_model = tf.keras.models.clone_model(model)
                best_model.set_weights(model.get_weights())
                best_loss = current_loss
                best_iteration = iteration
    # Final evaluation
    final_metrics = evaluate_model_cnn(best_model, X_test, y_test)
    training_info = {
        'initial_info': initial_info,
        'iteration_metrics': iteration_metrics,
        'best_iteration': best_iteration,
        'final_metrics': final_metrics
    }
    return best_model, model, training_info


def main():
    """Main training function"""
    # Load data
    train_data, test_data, low_conf_data, xu_data, atkins_data = load_data()
    # Prepare data
    X_train, y_train, X_test, y_test, sample_weights, _, _ = prepare_data(
        train_data, test_data, max_length=Config.Sequence_len
    )
    # Prepare external validation data
    X_xu = y_xu = X_atkins = y_atkins = None
    if xu_data is not None and not xu_data.empty:
        X_xu, y_xu, _, _, _, _, _ = prepare_data(
            xu_data, pd.DataFrame(), max_length=Config.Sequence_len
        )
    if atkins_data is not None and not atkins_data.empty:
        X_atkins, y_atkins, _, _, _, _, _ = prepare_data(
            atkins_data, pd.DataFrame(), max_length=Config.Sequence_len
        )
    # Train model
    best_model, final_model, training_info = train_bilstm_cnn_model(
        X_train, y_train, X_test, y_test, sample_weights,
        X_xu=X_xu, y_xu=y_xu, X_atkins=X_atkins, y_atkins=y_atkins
    )
    # Save results
    save_training_info(best_model, training_info, BaseConfig.BILSTM_MODEL_DIR, "best")
    save_training_info(final_model, training_info, BaseConfig.BILSTM_MODEL_DIR, "final", is_final_model=True)
    return best_model, final_model, training_info


if __name__ == "__main__":
    BaseConfig.create_directories()
    main()
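
# Assumed invocation (so that the `utils` package resolves): run from the
# repository root, e.g. `python train_models/bilstm_cnn.py`, or add the repo
# root to PYTHONPATH.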