# FScanpy-commit-code/model_feature/feature_analysis.py
"""
Feature Importance Analysis for FScanpy Models
"""
import os
import json
import numpy as np
import pandas as pd
import tensorflow as tf
from sklearn.inspection import permutation_importance
import shap
import joblib
from utils.config import BaseConfig
from utils.function import load_data
from models.hist_gb import get_feature_names, sequence_to_features, GBConfig
from models.bilstm_cnn import Config as BiLSTMConfig, process_sequence, encode_sequence


class FeatureImportanceAnalyzer:
    """Feature importance analyzer class providing multiple analysis methods"""

    def __init__(self, gb_model_path, bilstm_model_path, output_dir="feature_importance_results"):
        """
        Initialize feature importance analyzer

        Args:
            gb_model_path: HistGradientBoosting model path
            bilstm_model_path: BiLSTM-CNN model path
            output_dir: Output directory for results
        """
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)
        # Create subdirectories
        self.gb_dir = os.path.join(output_dir, "gb_model")
        self.bilstm_dir = os.path.join(output_dir, "bilstm_model")
        self.combined_dir = os.path.join(output_dir, "combined_analysis")
        os.makedirs(self.gb_dir, exist_ok=True)
        os.makedirs(self.bilstm_dir, exist_ok=True)
        os.makedirs(self.combined_dir, exist_ok=True)
        # Load models
        self.gb_model = self._load_gb_model(gb_model_path)
        self.bilstm_model = self._load_bilstm_model(bilstm_model_path)
        # Load data
        self.train_data, self.test_data, _, self.xu_data, self.atkins_data = load_data()
        # Use the held-out test set for analysis
        self.validation_data = self.test_data
        # Separate frameshift (label == 1) and non-frameshift (label == 0) samples
        self.fs_samples = self.validation_data[self.validation_data['label'] == 1]
        self.nonfs_samples = self.validation_data[self.validation_data['label'] == 0]
        # Sequence lengths expected by each model
        self.gb_seq_length = GBConfig.SEQUENCE_LENGTH
        self.bilstm_seq_length = BiLSTMConfig.Sequence_len
        # Containers for analysis results
        self.gb_results = {}
        self.bilstm_results = {}
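
    # Both loaders return None on failure; downstream analyses check for
    # None and skip the corresponding model rather than raising.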
    def _load_gb_model(self, model_path):
        """Load HistGradientBoosting model"""
        try:
            return joblib.load(model_path)
        except Exception as e:
            print(f"Failed to load HistGradientBoosting model from {model_path}: {e}")
            return None

    def _load_bilstm_model(self, model_path):
        """Load BiLSTM-CNN model"""
        try:
            return tf.keras.models.load_model(model_path)
        except Exception as e:
            print(f"Failed to load BiLSTM-CNN model from {model_path}: {e}")
            return None

    def prepare_gb_data(self, data):
        """
        Prepare feature data for HistGradientBoosting model

        Args:
            data: DataFrame containing sequences

        Returns:
            features: Feature matrix
            labels: Label values (if data contains label column)
            sequences: Processed sequence list
        """
        features = []
        labels = []
        sequences = []
        for _, row in data.iterrows():
            sequence = row['full_seq']
            # Get MFE (minimum free energy) values, defaulting to 0.0 when missing
            mfe_values = {}
            if 'mfe_40bp' in row:
                mfe_values['mfe_40bp'] = row['mfe_40bp'] if pd.notna(row['mfe_40bp']) else 0.0
            if 'mfe_120bp' in row:
                mfe_values['mfe_120bp'] = row['mfe_120bp'] if pd.notna(row['mfe_120bp']) else 0.0
            # Convert to features
            feature_vector = sequence_to_features(sequence, self.gb_seq_length, mfe_values)
            features.append(feature_vector)
            sequences.append(sequence)
            if 'label' in row:
                labels.append(row['label'])
        features = np.array(features)
        labels = np.array(labels) if len(labels) > 0 else None
        return features, labels, sequences
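
    # Note: encode_sequence is assumed here to one-hot encode each
    # nucleotide (e.g. A/C/G/T -> 4 channels), yielding an input tensor of
    # shape (n_samples, Sequence_len, channels) for the BiLSTM-CNN model.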
    def prepare_bilstm_data(self, data):
        """
        Prepare feature data for BiLSTM-CNN model

        Args:
            data: DataFrame containing sequences

        Returns:
            encoded_sequences: Encoded feature tensor
            original_sequences: Original sequence list
        """
        sequences = [row['full_seq'] for _, row in data.iterrows()]
        # Encode sequences
        encoded = [
            encode_sequence(process_sequence(seq, self.bilstm_seq_length), self.bilstm_seq_length)
            for seq in sequences
        ]
        return np.array(encoded), sequences

    def gb_built_in_importance(self, use_positive_only=True):
        """
        Feature importance analysis using the model's built-in importance

        Returns:
            importance_df: Feature importance DataFrame, or None if the
            model does not expose feature_importances_
        """
        try:
            data = self.fs_samples if use_positive_only else self.validation_data
            X, y, _ = self.prepare_gb_data(data)
            if self.gb_model is None or X is None or len(X) == 0:
                return None
            # Built-in feature importance. Note that sklearn's
            # HistGradientBoostingClassifier does not expose
            # feature_importances_, so this branch only runs if the
            # loaded model provides that attribute.
            if hasattr(self.gb_model, 'feature_importances_'):
                feature_names = get_feature_names(self.gb_seq_length)
                importance_scores = self.gb_model.feature_importances_
                importance_df = pd.DataFrame({
                    'feature': feature_names,
                    'importance': importance_scores
                }).sort_values('importance', ascending=False)
                # Save results
                output_path = os.path.join(self.gb_dir, 'built_in_importance.csv')
                importance_df.to_csv(output_path, index=False)
                self.gb_results['built_in_importance'] = importance_df
                return importance_df
            return None
        except Exception as e:
            print(f"Built-in importance analysis failed: {e}")
            return None
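
    # Permutation importance shuffles one feature column at a time and
    # measures the resulting drop in the model's score. Caveat: with
    # use_positive_only=True the labels are all 1, so score-based
    # importances can be degenerate; the full validation set is usually
    # more informative for this method.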
    def gb_permutation_importance(self, n_repeats=10, use_positive_only=True):
        """
        Permutation importance analysis for HistGradientBoosting model

        Args:
            n_repeats: Number of repeats
            use_positive_only: Whether to use only positive samples

        Returns:
            Feature importance DataFrame
        """
        try:
            data = self.fs_samples if use_positive_only else self.validation_data
            X, y, _ = self.prepare_gb_data(data)
            if self.gb_model is None or X is None or y is None:
                return None
            result = permutation_importance(self.gb_model, X, y, n_repeats=n_repeats, random_state=42)
            feature_names = get_feature_names(self.gb_seq_length)
            importance_df = pd.DataFrame({
                'feature': feature_names,
                'importance': result.importances_mean,
                'std': result.importances_std
            }).sort_values('importance', ascending=False)
            # Save results
            output_path = os.path.join(self.gb_dir, 'permutation_importance.csv')
            importance_df.to_csv(output_path, index=False)
            self.gb_results['permutation_importance'] = importance_df
            return importance_df
        except Exception as e:
            print(f"Permutation importance analysis failed: {e}")
            return None
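
    # shap.Explainer wrapped around a bare predict function typically
    # dispatches to a model-agnostic (permutation-based) explainer, with X
    # doubling as the background data, so subsampling via max_samples
    # keeps the cost manageable.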
    def gb_shap_analysis(self, max_samples=100, use_positive_only=True):
        """
        SHAP analysis for HistGradientBoosting model

        Args:
            max_samples: Maximum number of samples
            use_positive_only: Whether to use only positive samples

        Returns:
            SHAP values DataFrame
        """
        try:
            data = self.fs_samples if use_positive_only else self.validation_data
            if len(data) > max_samples:
                data = data.sample(n=max_samples, random_state=42)
            X, y, _ = self.prepare_gb_data(data)
            if self.gb_model is None or X is None:
                return None
            explainer = shap.Explainer(self.gb_model.predict, X)
            shap_values = explainer(X)
            feature_names = get_feature_names(self.gb_seq_length)
            # Mean absolute SHAP value per feature
            importance = np.abs(shap_values.values).mean(axis=0)
            shap_df = pd.DataFrame({
                'feature': feature_names,
                'shap_importance': importance
            }).sort_values('shap_importance', ascending=False)
            # Save results
            output_path = os.path.join(self.gb_dir, 'shap_importance.csv')
            shap_df.to_csv(output_path, index=False)
            self.gb_results['shap_importance'] = shap_df
            return shap_df
        except Exception as e:
            print(f"SHAP analysis failed: {e}")
            return None

    def bilstm_integrated_gradients(self, n_samples=20, steps=50, use_positive_only=True):
        """
        Integrated gradients analysis for BiLSTM-CNN model

        Args:
            n_samples: Number of samples for analysis
            steps: Number of integration steps
            use_positive_only: Whether to use only positive samples

        Returns:
            Average per-position importance scores DataFrame
        """
        try:
            data = self.fs_samples if use_positive_only else self.validation_data
            if len(data) > n_samples:
                data = data.sample(n=n_samples, random_state=42)
            X, sequences = self.prepare_bilstm_data(data)
            if self.bilstm_model is None or X is None or len(X) == 0:
                return None
            # Create baseline (all zeros)
            baseline = np.zeros_like(X, dtype=np.float32)
            X = X.astype(np.float32)
            # Compute integrated gradients
            ig_values = self.integrated_gradients(self.bilstm_model, X, baseline, steps=steps)
            # Average over samples, then collapse the encoding channels
            # (if any) so one importance score remains per sequence position
            avg_importance = ig_values.mean(axis=0)
            if avg_importance.ndim > 1:
                avg_importance = avg_importance.sum(axis=-1)
            positions = list(range(1, self.bilstm_seq_length + 1))
            importance_df = pd.DataFrame({
                'position': positions,
                'importance': avg_importance
            })
            # Save results
            output_path = os.path.join(self.bilstm_dir, 'integrated_gradients.csv')
            importance_df.to_csv(output_path, index=False)
            self.bilstm_results['integrated_gradients'] = importance_df
            return importance_df
        except Exception as e:
            print(f"Integrated gradients analysis failed: {e}")
            return None
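
    # Integrated gradients approximate
    #   IG_i(x) = (x_i - x'_i) * integral_0^1 dF(x' + a*(x - x')) / dx_i da
    # by averaging the model's gradients at `steps + 1` points interpolated
    # between the baseline x' and the input x (a Riemann-sum approximation).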
    def integrated_gradients(self, model, inputs, baseline, steps=50):
        """
        Calculate integrated gradients

        Args:
            model: Model
            inputs: Input samples
            baseline: Baseline input
            steps: Integration steps

        Returns:
            Integrated gradient values
        """
        # Generate interpolated inputs along the straight-line path
        alphas = np.linspace(0.0, 1.0, steps + 1)
        gradients = []
        for alpha in alphas:
            # Convert to a tf.Tensor so GradientTape can track it
            interpolated = tf.convert_to_tensor(baseline + alpha * (inputs - baseline))
            with tf.GradientTape() as tape:
                tape.watch(interpolated)
                predictions = model(interpolated)
            grads = tape.gradient(predictions, interpolated)
            gradients.append(grads.numpy())
        # Average gradients over the path and scale by (inputs - baseline)
        gradients = np.array(gradients)
        integrated_grads = np.mean(gradients, axis=0) * (inputs - baseline)
        return integrated_grads

    def run_all_analyses(self, max_samples=20, use_positive_only=True):
        """
        Run all feature importance analysis methods

        Args:
            max_samples: Maximum number of samples
            use_positive_only: Whether to use only positive samples
        """
        results = {}
        # GB model analyses
        if self.gb_model is not None:
            results['gb_built_in'] = self.gb_built_in_importance(use_positive_only)
            results['gb_permutation'] = self.gb_permutation_importance(use_positive_only=use_positive_only)
            results['gb_shap'] = self.gb_shap_analysis(max_samples, use_positive_only)
        # BiLSTM model analyses
        if self.bilstm_model is not None:
            results['bilstm_integrated_gradients'] = self.bilstm_integrated_gradients(
                max_samples, use_positive_only=use_positive_only
            )
        # Save combined results summary
        self._save_analysis_summary(results)
        return results

    def _save_analysis_summary(self, results):
        """Save analysis summary"""
        try:
            summary = {}
            for method, result in results.items():
                if result is not None:
                    summary[method] = {
                        'status': 'completed',
                        'num_features': len(result) if hasattr(result, '__len__') else 'N/A'
                    }
                else:
                    summary[method] = {'status': 'failed'}
            # Save summary as JSON
            summary_path = os.path.join(self.combined_dir, 'analysis_summary.json')
            with open(summary_path, 'w') as f:
                json.dump(summary, f, indent=2)
        except Exception as e:
            print(f"Failed to save analysis summary: {e}")


def main():
    """Main analysis function"""
    try:
        # Model paths (placeholders resolved from BaseConfig)
        gb_model_path = os.path.join(BaseConfig.GB_MODEL_DIR, "best_model.joblib")
        bilstm_model_path = os.path.join(BaseConfig.BILSTM_MODEL_DIR, "best_model.h5")
        output_dir = os.path.join(BaseConfig.RESULT_DIR, "feature_importance")
        # Create analyzer
        analyzer = FeatureImportanceAnalyzer(
            gb_model_path,
            bilstm_model_path,
            output_dir
        )
        # Run all analyses
        analyzer.run_all_analyses(max_samples=20, use_positive_only=True)
    except Exception as e:
        print(f"Feature importance analysis failed: {e}")


if __name__ == "__main__":
    main()