# FScanpy-commit-code/model_feature/feature_analysis.py
"""
Feature Importance Analysis for FScanpy Models
"""
import os
import json

import numpy as np
import pandas as pd
import tensorflow as tf
import joblib
import shap
from sklearn.inspection import permutation_importance

from utils.config import BaseConfig
from utils.function import load_data
from models.hist_gb import get_feature_names, sequence_to_features, GBConfig
from models.bilstm_cnn import Config as BiLSTMConfig, process_sequence, encode_sequence

class FeatureImportanceAnalyzer:
    """Feature importance analyzer providing multiple analysis methods."""

    def __init__(self, gb_model_path, bilstm_model_path, output_dir="feature_importance_results"):
        """
        Initialize the feature importance analyzer.

        Args:
            gb_model_path: Path to the HistGradientBoosting model
            bilstm_model_path: Path to the BiLSTM-CNN model
            output_dir: Output directory for results
        """
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)

        # Create subdirectories for each model and for combined analysis
        self.gb_dir = os.path.join(output_dir, "gb_model")
        self.bilstm_dir = os.path.join(output_dir, "bilstm_model")
        self.combined_dir = os.path.join(output_dir, "combined_analysis")
        os.makedirs(self.gb_dir, exist_ok=True)
        os.makedirs(self.bilstm_dir, exist_ok=True)
        os.makedirs(self.combined_dir, exist_ok=True)

        # Load models
        self.gb_model = self._load_gb_model(gb_model_path)
        self.bilstm_model = self._load_bilstm_model(bilstm_model_path)

        # Load data; the test set is used for analysis
        self.train_data, self.test_data, _, self.xu_data, self.atkins_data = load_data()
        self.validation_data = self.test_data

        # Separate frameshift (label == 1) and non-frameshift (label == 0) samples
        self.fs_samples = self.validation_data[self.validation_data['label'] == 1]
        self.nonfs_samples = self.validation_data[self.validation_data['label'] == 0]

        # Sequence lengths expected by each model
        self.gb_seq_length = GBConfig.SEQUENCE_LENGTH
        self.bilstm_seq_length = BiLSTMConfig.Sequence_len

        # Analysis results, keyed by method name
        self.gb_results = {}
        self.bilstm_results = {}

    def _load_gb_model(self, model_path):
        """Load the HistGradientBoosting model, returning None on failure."""
        try:
            return joblib.load(model_path)
        except Exception as e:
            print(f"Warning: failed to load GB model from {model_path}: {e}")
            return None

    def _load_bilstm_model(self, model_path):
        """Load the BiLSTM-CNN model, returning None on failure."""
        try:
            return tf.keras.models.load_model(model_path)
        except Exception as e:
            print(f"Warning: failed to load BiLSTM-CNN model from {model_path}: {e}")
            return None

    def prepare_gb_data(self, data):
        """
        Prepare feature data for the HistGradientBoosting model.

        Args:
            data: DataFrame containing sequences

        Returns:
            features: Feature matrix
            labels: Label array (None if data has no 'label' column)
            sequences: List of processed sequences
        """
        features = []
        labels = []
        sequences = []
        for _, row in data.iterrows():
            sequence = row['full_seq']
            # Collect MFE values, defaulting missing entries to 0.0
            mfe_values = {}
            if 'mfe_40bp' in row:
                mfe_values['mfe_40bp'] = row['mfe_40bp'] if pd.notna(row['mfe_40bp']) else 0.0
            if 'mfe_120bp' in row:
                mfe_values['mfe_120bp'] = row['mfe_120bp'] if pd.notna(row['mfe_120bp']) else 0.0
            # Convert the sequence (plus MFE values) into a feature vector
            feature_vector = sequence_to_features(sequence, self.gb_seq_length, mfe_values)
            features.append(feature_vector)
            sequences.append(sequence)
            if 'label' in row:
                labels.append(row['label'])
        features = np.array(features)
        labels = np.array(labels) if len(labels) > 0 else None
        return features, labels, sequences

    def prepare_bilstm_data(self, data):
        """
        Prepare feature data for the BiLSTM-CNN model.

        Args:
            data: DataFrame containing sequences

        Returns:
            encoded: Encoded feature tensor
            sequences: List of original sequences
        """
        sequences = [row['full_seq'] for _, row in data.iterrows()]
        # Trim/pad each sequence to the model's expected length, then encode it
        encoded = [
            encode_sequence(process_sequence(seq, self.bilstm_seq_length), self.bilstm_seq_length)
            for seq in sequences
        ]
        return np.array(encoded), sequences

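    # Aside: each encoded sequence is assumed to be a per-position channel
    # matrix of shape (Sequence_len, n_channels) from encode_sequence, so the
    # stacked array fed to the BiLSTM-CNN has shape
    # (n_samples, Sequence_len, n_channels).
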
    def gb_permutation_importance(self, n_repeats=10, use_positive_only=True):
        """
        Permutation importance analysis for the HistGradientBoosting model.

        Args:
            n_repeats: Number of shuffles per feature
            use_positive_only: Whether to use only positive (frameshift) samples

        Returns:
            Feature importance DataFrame, or None on failure
        """
        try:
            data = self.fs_samples if use_positive_only else self.validation_data
            X, y, _ = self.prepare_gb_data(data)
            if self.gb_model is None or X is None or y is None:
                return None
            result = permutation_importance(self.gb_model, X, y, n_repeats=n_repeats, random_state=42)
            feature_names = get_feature_names(self.gb_seq_length)
            importance_df = pd.DataFrame({
                'feature': feature_names,
                'importance': result.importances_mean,
                'std': result.importances_std
            }).sort_values('importance', ascending=False)
            # Save results
            output_path = os.path.join(self.gb_dir, 'permutation_importance.csv')
            importance_df.to_csv(output_path, index=False)
            self.gb_results['permutation_importance'] = importance_df
            return importance_df
        except Exception as e:
            print(f"Warning: permutation importance analysis failed: {e}")
            return None

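    # Aside: sklearn's permutation_importance shuffles one feature column at a
    # time and reports the mean drop in the model's score over `n_repeats`
    # shuffles. Large positive values mean the model relies on that feature;
    # values near zero (or negative) suggest it is unused or redundant.
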
    def gb_shap_analysis(self, max_samples=100, use_positive_only=True):
        """
        SHAP analysis for the HistGradientBoosting model.

        Args:
            max_samples: Maximum number of samples to explain
            use_positive_only: Whether to use only positive (frameshift) samples

        Returns:
            SHAP importance DataFrame, or None on failure
        """
        try:
            data = self.fs_samples if use_positive_only else self.validation_data
            if len(data) > max_samples:
                data = data.sample(n=max_samples, random_state=42)
            X, y, _ = self.prepare_gb_data(data)
            if self.gb_model is None or X is None:
                return None
            explainer = shap.Explainer(self.gb_model.predict, X)
            shap_values = explainer(X)
            feature_names = get_feature_names(self.gb_seq_length)
            # Global importance: mean absolute SHAP value per feature
            importance = np.abs(shap_values.values).mean(axis=0)
            shap_df = pd.DataFrame({
                'feature': feature_names,
                'shap_importance': importance
            }).sort_values('shap_importance', ascending=False)
            # Save results
            output_path = os.path.join(self.gb_dir, 'shap_importance.csv')
            shap_df.to_csv(output_path, index=False)
            self.gb_results['shap_importance'] = shap_df
            return shap_df
        except Exception as e:
            print(f"Warning: SHAP analysis failed: {e}")
            return None

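    # Aside: shap.Explainer wrapped around a bare predict callable selects a
    # model-agnostic algorithm (typically permutation-based SHAP), with X
    # doubling as the background masker. The global score reported above is
    # mean(|SHAP value|) per feature, a standard summary of average
    # contribution magnitude across the sampled rows.
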
    def bilstm_integrated_gradients(self, n_samples=20, steps=50, use_positive_only=True):
        """
        Integrated gradients analysis for the BiLSTM-CNN model.

        Args:
            n_samples: Number of samples to analyze
            steps: Number of integration steps
            use_positive_only: Whether to use only positive (frameshift) samples

        Returns:
            Per-position importance DataFrame, or None on failure
        """
        try:
            data = self.fs_samples if use_positive_only else self.validation_data
            if len(data) > n_samples:
                data = data.sample(n=n_samples, random_state=42)
            X, sequences = self.prepare_bilstm_data(data)
            if self.bilstm_model is None or X is None or len(X) == 0:
                return None
            # All-zero baseline of the same shape as the inputs
            baseline = np.zeros_like(X, dtype=np.float32)
            X = X.astype(np.float32)
            # Compute integrated gradients
            ig_values = self.integrated_gradients(self.bilstm_model, X, baseline, steps=steps)
            # Collapse the encoding channel axis (assumed to be last), then
            # average the attribution magnitude over samples per position
            avg_importance = np.abs(ig_values).sum(axis=-1).mean(axis=0)
            positions = list(range(1, self.bilstm_seq_length + 1))
            importance_df = pd.DataFrame({
                'position': positions,
                'importance': avg_importance
            })
            # Save results
            output_path = os.path.join(self.bilstm_dir, 'integrated_gradients.csv')
            importance_df.to_csv(output_path, index=False)
            self.bilstm_results['integrated_gradients'] = importance_df
            return importance_df
        except Exception as e:
            print(f"Warning: integrated gradients analysis failed: {e}")
            return None

    def integrated_gradients(self, model, inputs, baseline, steps=50):
        """
        Calculate integrated gradients.

        Args:
            model: Keras model
            inputs: Input samples
            baseline: Baseline input (same shape as inputs)
            steps: Number of integration steps

        Returns:
            Integrated gradient values (same shape as inputs)
        """
        # Generate interpolated inputs between the baseline and the samples
        alphas = np.linspace(0.0, 1.0, steps + 1)
        gradients = []
        for alpha in alphas:
            # tape.watch requires a tf.Tensor, so convert the interpolation
            interpolated = tf.convert_to_tensor(baseline + alpha * (inputs - baseline), dtype=tf.float32)
            with tf.GradientTape() as tape:
                tape.watch(interpolated)
                predictions = model(interpolated)
            grads = tape.gradient(predictions, interpolated)
            gradients.append(grads.numpy())
        # Average the gradients along the path and scale by the input difference
        gradients = np.array(gradients)
        integrated_grads = np.mean(gradients, axis=0) * (inputs - baseline)
        return integrated_grads

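    # The method above is a discretized form of the integrated gradients
    # attribution (Sundararajan et al., 2017):
    #
    #     IG_i(x) = (x_i - x'_i) * ∫_0^1 ∂F(x' + α(x - x')) / ∂x_i dα
    #
    # where x' is the baseline. The loop approximates the integral as a
    # Riemann sum over `steps + 1` evenly spaced α values, so the attribution
    # converges to the exact path integral as `steps` grows.
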
    def bilstm_saliency_map(self, n_samples=20, use_positive_only=True):
        """
        Saliency map analysis for the BiLSTM-CNN model.

        Args:
            n_samples: Number of samples to analyze
            use_positive_only: Whether to use only positive (frameshift) samples

        Returns:
            Per-position saliency DataFrame, or None on failure
        """
        try:
            data = self.fs_samples if use_positive_only else self.validation_data
            if len(data) > n_samples:
                data = data.sample(n=n_samples, random_state=42)
            X, sequences = self.prepare_bilstm_data(data)
            if self.bilstm_model is None or X is None or len(X) == 0:
                return None
            # Compute a saliency map per sample, then average over samples
            saliency_maps = []
            for i in range(len(X)):
                saliency_maps.append(self._calculate_saliency(self.bilstm_model, X[i:i + 1]))
            avg_saliency = np.mean(saliency_maps, axis=0)
            positions = list(range(1, self.bilstm_seq_length + 1))
            importance_df = pd.DataFrame({
                'position': positions,
                'saliency': avg_saliency
            })
            output_path = os.path.join(self.bilstm_dir, 'saliency_map.csv')
            importance_df.to_csv(output_path, index=False)
            self.bilstm_results['saliency_map'] = importance_df
            return importance_df
        except Exception as e:
            print(f"Warning: saliency map analysis failed: {e}")
            return None

    def _calculate_saliency(self, model, input_tensor):
        """Calculate the saliency map for a single input."""
        input_tensor = tf.convert_to_tensor(input_tensor, dtype=tf.float32)
        with tf.GradientTape() as tape:
            tape.watch(input_tensor)
            predictions = model(input_tensor)
            # Assuming binary classification; target the positive-class output
            target_class_output = predictions[:, 0]
        gradients = tape.gradient(target_class_output, input_tensor)
        # Gradient magnitude summed over encoding channels, one value per position
        saliency = np.abs(gradients).sum(axis=-1)[0]
        # Normalize to [0, 1] for comparability across samples
        if np.max(saliency) > 0:
            saliency = saliency / np.max(saliency)
        return saliency

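    # Aside: the per-position saliency is the gradient magnitude of the
    # positive-class output with respect to the input, summed over the
    # encoding channel axis (assumed to be last),
    # i.e. saliency_j = Σ_c |∂F/∂x_{j,c}| for sequence position j.
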
    def run_all_analyses(self, max_samples=20, use_positive_only=True):
        """
        Run all feature importance analysis methods.

        Args:
            max_samples: Maximum number of samples per analysis
            use_positive_only: Whether to use only positive (frameshift) samples
        """
        results = {}
        # GB model analyses
        if self.gb_model is not None:
            results['gb_permutation'] = self.gb_permutation_importance(use_positive_only=use_positive_only)
            results['gb_shap'] = self.gb_shap_analysis(max_samples, use_positive_only)
        # BiLSTM model analyses
        if self.bilstm_model is not None:
            results['bilstm_integrated_gradients'] = self.bilstm_integrated_gradients(
                n_samples=max_samples, use_positive_only=use_positive_only
            )
            results['bilstm_saliency_map'] = self.bilstm_saliency_map(
                n_samples=max_samples, use_positive_only=use_positive_only
            )
        # Save combined results summary
        self._save_analysis_summary(results)
        return results

    def _save_analysis_summary(self, results):
        """Save a JSON summary of which analyses completed."""
        try:
            summary = {}
            for method, result in results.items():
                if result is not None:
                    summary[method] = {
                        'status': 'completed',
                        'num_features': len(result) if hasattr(result, '__len__') else 'N/A'
                    }
                else:
                    summary[method] = {'status': 'failed'}
            summary_path = os.path.join(self.combined_dir, 'analysis_summary.json')
            with open(summary_path, 'w') as f:
                json.dump(summary, f, indent=2)
        except Exception as e:
            print(f"Warning: failed to save analysis summary: {e}")

def main():
    """Main analysis function."""
    try:
        # Model paths (resolved from the project configuration)
        gb_model_path = os.path.join(BaseConfig.GB_MODEL_DIR, "best_model.joblib")
        bilstm_model_path = os.path.join(BaseConfig.BILSTM_MODEL_DIR, "best_model.h5")
        output_dir = os.path.join(BaseConfig.RESULT_DIR, "feature_importance")
        # Create the analyzer and run all analyses
        analyzer = FeatureImportanceAnalyzer(
            gb_model_path,
            bilstm_model_path,
            output_dir
        )
        analyzer.run_all_analyses(max_samples=20, use_positive_only=True)
    except Exception as e:
        print(f"Error: feature importance analysis failed: {e}")


if __name__ == "__main__":
    main()
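
# Example interactive usage (hypothetical paths; adjust to your checkout):
#
#     analyzer = FeatureImportanceAnalyzer(
#         "results/gb/best_model.joblib",
#         "results/bilstm/best_model.h5",
#         output_dir="feature_importance_results",
#     )
#     analyzer.gb_permutation_importance(n_repeats=5, use_positive_only=False)
#     analyzer.bilstm_integrated_gradients(n_samples=10, steps=25)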