# FScanpy-commit-code/train_models/hist_gb.py

"""
HistGradientBoosting Model with MFE Features
"""
import os
import numpy as np
import pandas as pd
import itertools
from sklearn.ensemble import HistGradientBoostingClassifier
from sklearn.inspection import permutation_importance
from utils.function import evaluate_model_gb
from utils.config import BaseConfig


class GBConfig:
    """HistGradientBoostingClassifier model configuration"""
    # Model training parameters
    MAX_ITER = 10000
    LEARNING_RATE = 0.4
    MAX_DEPTH = 5
    RANDOM_STATE = 42

    # Early stopping parameters
    EARLY_STOPPING = True
    N_ITER_NO_CHANGE = 10
    SCORING = 'loss'

    # Sequence parameters
    SEQUENCE_LENGTH = 33  # Must be a multiple of 3 (codon length)

    # Validation parameters
    VALIDATION_FRACTION = 0.2
    SMALL_VALIDATION_FRACTION = 0.1


def load_data(neg_samples=20000):
    """Load training, test, and external validation data.

    Note: the neg_samples parameter is currently unused.
    """
    try:
        train_data = pd.read_csv(os.path.join(BaseConfig.DATA_DIR, "merged_train_data.csv"))
        test_data = pd.read_csv(os.path.join(BaseConfig.DATA_DIR, "merged_test_data.csv"))
        validation_data = pd.read_csv(os.path.join(BaseConfig.DATA_DIR, "merged_validation_data.csv"))

        # Ensure the columns needed downstream exist, filling neutral defaults
        required_columns = ['full_seq', 'label']
        for df in [train_data, test_data, validation_data]:
            for col in required_columns:
                if col not in df.columns:
                    df[col] = 0 if col == 'label' else ''

        # Split the external validation set by source. The subsets are copies
        # of validation_data rows, so they inherit the columns filled above.
        if 'source' in validation_data.columns:
            xu_data = validation_data[validation_data['source'] == 'Xu'].copy()
            atkins_data = validation_data[validation_data['source'] == 'Atkins'].copy()
        else:
            xu_data = atkins_data = None

        return train_data, test_data, validation_data, xu_data, atkins_data
    except Exception as e:
        print(f"Failed to load data: {e}")
        return None, None, None, None, None
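
# Expected CSV schema, inferred from how the columns are used in this file
# (the upstream data files are not documented here):
#   full_seq       DNA sequence window; trimmed to SEQUENCE_LENGTH around its center
#   label          binary target (0/1)
#   mfe_40bp,
#   mfe_120bp      optional minimum-free-energy features (default 0.0 when absent)
#   sample_weight  optional per-row training weight (default 1.0)
#   source         provenance tag for validation rows ('Xu' or 'Atkins')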


def train_hist_model(X_train, y_train, X_test, y_test, sample_weights=None,
                     X_xu=None, y_xu=None, X_atkins=None, y_atkins=None):
    """Train a HistGradientBoostingClassifier model"""
    # Reserve a smaller internal validation split when external sets are supplied
    validation_fraction = GBConfig.VALIDATION_FRACTION
    if X_xu is not None or X_atkins is not None:
        validation_fraction = GBConfig.SMALL_VALIDATION_FRACTION

    # Create and train model
    model = HistGradientBoostingClassifier(
        max_iter=GBConfig.MAX_ITER,
        learning_rate=GBConfig.LEARNING_RATE,
        max_depth=GBConfig.MAX_DEPTH,
        random_state=GBConfig.RANDOM_STATE,
        early_stopping=GBConfig.EARLY_STOPPING,
        n_iter_no_change=GBConfig.N_ITER_NO_CHANGE,
        scoring=GBConfig.SCORING,
        validation_fraction=validation_fraction
    )
    model.fit(X_train, y_train, sample_weight=sample_weights)

    # Evaluate on the test set
    test_metrics = evaluate_model_gb(model, X_test, y_test)

    # Evaluate on the external validation sets
    xu_metrics = None
    if X_xu is not None and y_xu is not None:
        xu_metrics = evaluate_model_gb(model, X_xu, y_xu)
    atkins_metrics = None
    if X_atkins is not None and y_atkins is not None:
        atkins_metrics = evaluate_model_gb(model, X_atkins, y_atkins)

    # Collect training info
    training_info = {
        'n_iter': model.n_iter_,
        'train_score': model.train_score_,
        'validation_scores': getattr(model, 'validation_scores_', None),
        'final_metrics': {
            'test': test_metrics,
            'xu': xu_metrics,
            'atkins': atkins_metrics
        }
    }
    return model, test_metrics, training_info
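
# Minimal usage sketch (hypothetical shapes, for illustration only; real
# feature matrices come from prepare_data() below):
#
#   X = np.random.rand(200, 2630)
#   y = np.random.randint(0, 2, size=200)
#   model, metrics, info = train_hist_model(X[:160], y[:160], X[160:], y[160:])
#   print(info['n_iter'], metrics)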


def get_feature_names(seq_length=33):
    """Return feature names including all possible base features and MFE features"""
    features = []

    # Single nucleotide features
    bases = ['A', 'T', 'C', 'G']
    for i in range(seq_length):
        for base in bases:
            features.append(f'pos_{i+1}_{base}')

    # Dinucleotide features
    dinucleotides = [''.join(pair) for pair in itertools.product(bases, repeat=2)]
    for i in range(seq_length - 1):
        for dinuc in dinucleotides:
            features.append(f'dinuc_{i+1}_{dinuc}')

    # Trinucleotide (codon) features
    trinucleotides = [''.join(triplet) for triplet in itertools.product(bases, repeat=3)]
    for i in range(seq_length - 2):
        for trinuc in trinucleotides:
            features.append(f'codon_{i+1}_{trinuc}')

    # MFE features
    features.extend(['mfe_40bp', 'mfe_120bp'])
    return features
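
# Feature-count check (worked arithmetic, not in the original source):
# with seq_length = 33 the vector has
#   33 * 4  = 132   single-nucleotide one-hot features
#   32 * 16 = 512   dinucleotide features
#   31 * 64 = 1984  trinucleotide (codon) features
#   + 2             MFE features
# for 2630 columns in total, matching sequence_to_features() below.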


def trim_sequence(seq, target_length):
    """Trim a sequence from both ends to the target length, keeping the center"""
    if len(seq) <= target_length:
        return seq
    excess = len(seq) - target_length
    left_trim = excess // 2
    right_trim = excess - left_trim
    return seq[left_trim:len(seq) - right_trim]
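
# Behavior examples (derived from the code above, shown for clarity):
#   trim_sequence('AAACCCGGG', 3)  -> 'CCC'  (3 bases cut from each side)
#   trim_sequence('AAACCCGGGT', 3) -> 'CCC'  (odd excess: the extra base comes off the right)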


def sequence_to_features(sequence, seq_length=33, mfe_values=None):
    """Convert a DNA sequence to a feature vector including MFE features"""
    # Trim sequence to target length
    trimmed_seq = trim_sequence(sequence.upper(), seq_length)
    feature_vector = []

    # Single nucleotide features (one-hot encoding)
    bases = ['A', 'T', 'C', 'G']
    for i in range(seq_length):
        for base in bases:
            if i < len(trimmed_seq) and trimmed_seq[i] == base:
                feature_vector.append(1)
            else:
                feature_vector.append(0)

    # Dinucleotide features
    dinucleotides = [''.join(pair) for pair in itertools.product(bases, repeat=2)]
    for i in range(seq_length - 1):
        for dinuc in dinucleotides:
            if i + 1 < len(trimmed_seq) and trimmed_seq[i:i+2] == dinuc:
                feature_vector.append(1)
            else:
                feature_vector.append(0)

    # Trinucleotide (codon) features
    trinucleotides = [''.join(triplet) for triplet in itertools.product(bases, repeat=3)]
    for i in range(seq_length - 2):
        for trinuc in trinucleotides:
            if i + 2 < len(trimmed_seq) and trimmed_seq[i:i+3] == trinuc:
                feature_vector.append(1)
            else:
                feature_vector.append(0)

    # MFE features: accept a dict, a 2-element sequence, or nothing
    if isinstance(mfe_values, dict):
        feature_vector.append(mfe_values.get('mfe_40bp', 0.0))
        feature_vector.append(mfe_values.get('mfe_120bp', 0.0))
    elif isinstance(mfe_values, (list, tuple)) and len(mfe_values) >= 2:
        feature_vector.extend(mfe_values[:2])
    else:
        feature_vector.extend([0.0, 0.0])

    return np.array(feature_vector)
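
# Illustrative call (hypothetical sequence and MFE values):
#   vec = sequence_to_features('ATG' * 11, mfe_values={'mfe_40bp': -8.2, 'mfe_120bp': -25.1})
#   vec.shape  # -> (2630,), per the feature-count check above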


def prepare_data(train_data, test_data, seq_length=33):
    """Prepare training and test data including MFE features"""
    # Process training data
    X_train = []
    y_train = []
    sample_weights = []
    for _, row in train_data.iterrows():
        sequence = row['full_seq']
        label = row['label']

        # Get MFE values, treating missing/NaN entries as 0.0
        mfe_values = {}
        if 'mfe_40bp' in row:
            mfe_values['mfe_40bp'] = row['mfe_40bp'] if pd.notna(row['mfe_40bp']) else 0.0
        if 'mfe_120bp' in row:
            mfe_values['mfe_120bp'] = row['mfe_120bp'] if pd.notna(row['mfe_120bp']) else 0.0

        # Convert to features
        features = sequence_to_features(sequence, seq_length, mfe_values)
        X_train.append(features)
        y_train.append(label)

        # Sample weight (defaults to 1.0 when the column is absent or NaN)
        weight = 1.0
        if 'sample_weight' in row and pd.notna(row['sample_weight']):
            weight = row['sample_weight']
        sample_weights.append(weight)

    X_train = np.array(X_train)
    y_train = np.array(y_train)
    sample_weights = np.array(sample_weights)

    # Process test data the same way (no sample weights needed)
    X_test = []
    y_test = []
    if test_data is not None and not test_data.empty:
        for _, row in test_data.iterrows():
            sequence = row['full_seq']
            label = row['label']
            mfe_values = {}
            if 'mfe_40bp' in row:
                mfe_values['mfe_40bp'] = row['mfe_40bp'] if pd.notna(row['mfe_40bp']) else 0.0
            if 'mfe_120bp' in row:
                mfe_values['mfe_120bp'] = row['mfe_120bp'] if pd.notna(row['mfe_120bp']) else 0.0
            features = sequence_to_features(sequence, seq_length, mfe_values)
            X_test.append(features)
            y_test.append(label)
    X_test = np.array(X_test) if X_test else None
    y_test = np.array(y_test) if y_test else None

    return X_train, y_train, X_test, y_test, sample_weights, train_data, test_data
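
# Resulting shapes (assuming n training rows and seq_length = 33):
#   X_train: (n, 2630), y_train: (n,), sample_weights: (n,)
#   X_test/y_test are None when test_data is empty.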


def analyze_feature_importance(model, X_test, y_test, test_data):
    """Analyze feature importance on the test set.

    HistGradientBoostingClassifier exposes no feature_importances_ attribute,
    so permutation importance is used instead of a built-in score.
    The test_data argument is kept for interface compatibility; unused here.
    """
    try:
        # Get feature names
        feature_names = get_feature_names(GBConfig.SEQUENCE_LENGTH)

        # Permutation importance: shuffle each feature column and measure
        # how much the test score degrades
        result = permutation_importance(
            model, X_test, y_test,
            n_repeats=5,
            random_state=GBConfig.RANDOM_STATE
        )

        # Create importance DataFrame
        importance_df = pd.DataFrame({
            'feature': feature_names,
            'importance': result.importances_mean
        }).sort_values('importance', ascending=False)

        # Save results
        importance_path = os.path.join(BaseConfig.GB_DIR, 'feature_importance.csv')
        importance_df.to_csv(importance_path, index=False)
        return {'permutation_importance': importance_df}
    except Exception as e:
        print(f"Feature importance analysis failed: {e}")
        return None
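
# Note: permutation importance scores the model n_repeats times per feature;
# with ~2630 features this is by far the slowest step of the script.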


def main():
    """Main training function"""
    try:
        # Set sequence length
        sequence_length = GBConfig.SEQUENCE_LENGTH

        # Load data
        train_data, test_data, _, xu_data, atkins_data = load_data()

        # Prepare data
        X_train, y_train, X_test, y_test, sample_weights, _, _ = prepare_data(
            train_data, test_data, seq_length=sequence_length
        )

        # Prepare external validation data
        X_xu = y_xu = X_atkins = y_atkins = None
        if xu_data is not None and not xu_data.empty:
            try:
                empty_test = pd.DataFrame(columns=xu_data.columns)
                X_xu, y_xu, _, _, _, _, _ = prepare_data(xu_data, empty_test, seq_length=sequence_length)
            except Exception:
                X_xu = y_xu = None
        if atkins_data is not None and not atkins_data.empty:
            try:
                empty_test = pd.DataFrame(columns=atkins_data.columns)
                X_atkins, y_atkins, _, _, _, _, _ = prepare_data(atkins_data, empty_test, seq_length=sequence_length)
            except Exception:
                X_atkins = y_atkins = None

        # Train model
        model, _, training_info = train_hist_model(
            X_train, y_train, X_test, y_test, sample_weights,
            X_xu=X_xu, y_xu=y_xu, X_atkins=X_atkins, y_atkins=y_atkins
        )

        # Feature importance analysis (results are written to GB_DIR)
        analyze_feature_importance(model, X_test, y_test, test_data)

        return model, training_info['final_metrics']['test']
    except Exception as e:
        print(f"Training failed: {e}")
        return None, None


if __name__ == "__main__":
    BaseConfig.create_directories()
    main()