Add structure description

parent 8a072a35bf
commit aa73f01179

@@ -0,0 +1,146 @@

# FScanpy Commit Code Architecture

## Code Manifest

- `data_pre/`
  - `mfe.py`
  - `seaphage_knn.py`
- `model_feature/`
  - `feature_analysis.py`
- `train_models/`
  - `bilstm_cnn.py`
  - `hist_gb.py`
- `utils/`
  - `config.py`
  - `function.py`

## Overview & Pipeline

This repository is organized around one main pipeline: read and preprocess sequence data → construct features and train models → analyze importance and save results.

1. **Data Preprocessing (`data_pre/`)**
   - Calculates structural energy features (MFE) and writes them back, and filters samples based on external confidence scores and KNN distance.
2. **Model Training (`train_models/`)**
   - The traditional gradient boosting model (`HistGradientBoostingClassifier`) uses explicit features (mono-/di-/tri-nucleotides + MFE).
   - The deep learning model (BiLSTM-CNN) learns representations end-to-end from encoded sequences and supports iterative self-training.
3. **Feature Importance Analysis (`model_feature/`)**
   - For the GB model: Permutation Importance and SHAP analysis.
   - For the BiLSTM-CNN model: Integrated Gradients and saliency map analysis.
4. **Common Configuration and Utilities (`utils/`)**
   - General-purpose functions for path/directory configuration, data loading, evaluation, and results saving.

---

## `data_pre/`

### `data_pre/mfe.py`

* **Purpose**:
  - Calls ViennaRNA (`import RNA`) to calculate the Minimum Free Energy (MFE) of given sequence windows and writes the results back to a CSV file.
* **Logic** (see the sketch after this section):
  - For each `full_seq`, crops subsequences according to the configuration (default start: 198; lengths: 40 and 120).
  - Uses `RNA.fold_compound(sub_seq).mfe()` to calculate the MFE and populates the specified columns (`mfe_40bp`, `mfe_120bp`).
* **Key Functions**:
  - `predict_rna_structure_multiple(csv_file, predictions)`: Writes back columns for multiple window configurations.
  - `calculate_mfe_features(data_file, output_file=None)`: The standard entry point; assembles the window configurations and calls the function above.
* **Input/Output**:
  - **Input**: A CSV file containing `full_seq` (e.g., `BaseConfig.VALIDATION_DATA`).
  - **Output**: Appends two columns, `mfe_40bp` and `mfe_120bp`, to the original CSV and saves it.
* **Dependencies**:
  - ViennaRNA Python interface (`RNA`).
  - `pandas`.

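The MFE computation itself is a thin wrapper around the ViennaRNA bindings. Below is a minimal sketch of the windowed calculation under the defaults described above; the helper name `add_mfe_column` and the file name are illustrative, not the module's actual API.

```python
import pandas as pd
import RNA  # ViennaRNA Python bindings

def add_mfe_column(df, start=198, length=40, column="mfe_40bp"):
    """Compute the MFE of one subsequence window per row and store it in `column`."""
    mfes = []
    for seq in df["full_seq"]:
        sub_seq = seq[start:start + length]
        _, mfe = RNA.fold_compound(sub_seq).mfe()  # returns (structure, MFE in kcal/mol)
        mfes.append(mfe)
    df[column] = mfes
    return df

df = pd.read_csv("validation.csv")  # must contain a `full_seq` column
df = add_mfe_column(df, length=40, column="mfe_40bp")
df = add_mfe_column(df, length=120, column="mfe_120bp")
df.to_csv("validation.csv", index=False)
```
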
### `data_pre/seaphage_knn.py`

* **Purpose**:
  - Selects medium/low-confidence neighbor samples based on externally provided `SEAPHAGES` subset confidence rankings (`final_rank` → `confidence`) and distances between one-hot encoded sequences, then merges them with high-confidence samples to create an augmented training set.
* **Logic** (see the sketch after this section):
  - Reads `BaseConfig.TRAIN_DATA`, `BaseConfig.TEST_DATA`, and `BaseConfig.SEAPHAGE_PROB` (these must be configured in `BaseConfig`).
  - Extracts samples from `SEAPHAGES` with `label==1` and aligns them with the probability table via the `DNA_seqid` prefix to generate `confidence`.
  - Converts `FS_period` sequences to one-hot encoding, standardizes them, and uses the high-confidence samples as a reference library; it then calculates the average KNN distance of each medium/low-confidence sample and filters by a quantile threshold.
  - Merges the results, annotates them with `confidence_level` (high/medium/low), and saves them to a specified CSV.
* **Input/Output**:
  - **Input**: Training/testing CSVs and the `seaphage_prob` probability table.
  - **Output**: The filtered `seaphage_selected.csv` (note: the output path is currently hardcoded and should be managed by `BaseConfig`).
* **Dependencies**:
  - `pandas`, `numpy`, `scikit-learn` (`NearestNeighbors`, `StandardScaler`).

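The KNN filtering step reduces to scikit-learn primitives. A minimal sketch under the description above; the helper names, `k`, and the quantile are assumptions, and the `DNA_seqid` alignment and merging logic are omitted.

```python
import numpy as np
import pandas as pd
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import StandardScaler

BASE_INDEX = {"A": 0, "T": 1, "C": 2, "G": 3}

def one_hot(seq):
    """Flat one-hot encoding of a DNA sequence (len(seq) x 4 values)."""
    mat = np.zeros((len(seq), 4))
    for i, base in enumerate(seq.upper()):
        if base in BASE_INDEX:
            mat[i, BASE_INDEX[base]] = 1.0
    return mat.ravel()

def knn_filter(high_df, cand_df, k=5, quantile=0.5):
    """Keep candidates whose mean distance to the k nearest
    high-confidence reference samples is below the given quantile."""
    scaler = StandardScaler()
    X_high = scaler.fit_transform(np.stack(high_df["FS_period"].map(one_hot).to_list()))
    X_cand = scaler.transform(np.stack(cand_df["FS_period"].map(one_hot).to_list()))

    nn = NearestNeighbors(n_neighbors=k).fit(X_high)
    dists, _ = nn.kneighbors(X_cand)  # shape: (n_candidates, k)
    mean_dist = dists.mean(axis=1)
    return cand_df[mean_dist <= np.quantile(mean_dist, quantile)]
```
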
---
## `train_models/`

### `train_models/hist_gb.py`

* **Purpose**:
  - Trains and evaluates a `HistGradientBoostingClassifier` for frameshift-site classification. Supports constructing explicit features from `full_seq` and appending MFE features.
* **Logic and Features** (see the sketch after this section):
  - A central crop of `GBConfig.SEQUENCE_LENGTH=33` maintains the reading frame, from which it constructs:
    - Mononucleotide one-hot features (4×L).
    - Dinucleotide features (16×(L-1)).
    - Trinucleotide (codon) features (64×(L-2)).
    - Structural energy features: `mfe_40bp`, `mfe_120bp`.
* **Key Functions**:
  - `sequence_to_features()`, `prepare_data()`: Generate the feature matrix and sample weights from `full_seq` and the MFE columns.
  - `train_hist_model()`: Handles training, the validation split, and evaluation (on the test set and the external `Xu/Atkins` sets).
  - `analyze_feature_importance()`: Exports built-in feature importance to a CSV file.
* **Input/Output**:
  - **Input**: Merged data from `BaseConfig.DATA_DIR` (`merged_train_data.csv`, `merged_test_data.csv`, etc.).
  - **Output**: Model object, evaluation metrics, and an importance CSV (saved to `BaseConfig.GB_DIR`).

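The explicit feature construction can be pictured as positional one-hot encodings for 1-, 2-, and 3-mers over the central crop. A minimal sketch; the real `sequence_to_features()` may order features or handle the crop differently.

```python
import numpy as np
from itertools import product

BASES = "ACGT"
DINUCS = ["".join(p) for p in product(BASES, repeat=2)]   # 16 dinucleotides
TRINUCS = ["".join(p) for p in product(BASES, repeat=3)]  # 64 trinucleotides (codons)

def sequence_to_features(seq, mfe_40, mfe_120, length=33):
    """Positional one-hot features for 1/2/3-mers of a central crop, plus MFE."""
    start = (len(seq) - length) // 2       # central crop keeps the reading frame
    seq = seq[start:start + length].upper()

    feats = []
    for i in range(length):                # mononucleotides: 4 x L
        feats.extend(float(seq[i] == b) for b in BASES)
    for i in range(length - 1):            # dinucleotides: 16 x (L-1)
        feats.extend(float(seq[i:i + 2] == d) for d in DINUCS)
    for i in range(length - 2):            # trinucleotides: 64 x (L-2)
        feats.extend(float(seq[i:i + 3] == t) for t in TRINUCS)
    feats.extend([mfe_40, mfe_120])        # structural energy features
    return np.array(feats)                 # 4*33 + 16*32 + 64*31 + 2 = 2630 values
```
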
### `train_models/bilstm_cnn.py`

* **Purpose**:
  - An end-to-end sequence classification model with a hybrid architecture: `Input` → `Embedding` → `BiLSTM` → parallel CNN branches → concatenation → dense layers → sigmoid output (sketched after this section).
  - Supports self-training: iteratively selects pseudo-labeled samples from a pool of low-confidence samples and adds them to the training set.
* **Logic**:
  - **Sequence encoding**: `encode_sequence()` maps 'ATCG' to {1,2,3,4} and pads/trims to `Config.Sequence_len=399`.
  - **Training monitoring**: `MetricsCallback` computes test-set metrics at each epoch to track the best-performing model.
  - **Self-training loop**: Calls `utils.function.select_low_confidence_samples_cnn()` to select samples based on the model probability and a given `final_prob`.
* **Key Functions**:
  - `create_bilstm_cnn_model()`, `prepare_data()`, `train_bilstm_cnn_model()`.
  - `main()`: Loads data, trains the model, and saves the best and final models along with the training information via `save_training_info()`.
* **Input/Output**:
  - **Input**: Train/test sets returned by `load_data()`, plus optional external validation sets (Xu/Atkins).
  - **Output**: Saves the `*.h5` model, `*_training_info.json`, and `*_weights.pkl` to `BaseConfig.BILSTM_MODEL_DIR`.

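The architecture maps directly onto Keras. A hedged sketch of both the encoding and the model; layer sizes, kernel widths, and the number of CNN branches are placeholders, not the repository's actual hyperparameters.

```python
import tensorflow as tf
from tensorflow.keras import layers, models

def encode_sequence(seq, max_len=399):
    """Map ATCG to {1,2,3,4}; 0 is reserved for padding/unknown bases."""
    mapping = {"A": 1, "T": 2, "C": 3, "G": 4}
    ids = [mapping.get(b, 0) for b in seq.upper()[:max_len]]
    return ids + [0] * (max_len - len(ids))

def create_bilstm_cnn_model(seq_len=399, vocab_size=5, embed_dim=64):
    """Input -> Embedding -> BiLSTM -> parallel CNN branches -> concat -> Dense -> sigmoid."""
    inputs = layers.Input(shape=(seq_len,))
    x = layers.Embedding(vocab_size, embed_dim)(inputs)
    x = layers.Bidirectional(layers.LSTM(64, return_sequences=True))(x)

    # Parallel convolutional branches with different receptive fields
    branches = []
    for kernel_size in (3, 5, 7):
        b = layers.Conv1D(64, kernel_size, padding="same", activation="relu")(x)
        b = layers.GlobalMaxPooling1D()(b)
        branches.append(b)

    x = layers.Concatenate()(branches)
    x = layers.Dense(64, activation="relu")(x)
    x = layers.Dropout(0.3)(x)
    outputs = layers.Dense(1, activation="sigmoid")(x)  # P(frameshift)

    model = models.Model(inputs, outputs)
    model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
    return model
```
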
---
## `model_feature/`

### `model_feature/feature_analysis.py`

* **Purpose**:
  - Provides a unified interface for feature importance analysis:
    - **GB model**: Permutation Importance (`sklearn.inspection.permutation_importance`; see the sketch after this section) and SHAP.
    - **BiLSTM-CNN model**: Integrated Gradients and gradient-based saliency maps.
* **Logic**:
  - Loads the trained models and validation sets, encodes the data according to each model's pipeline, calculates importance, and saves the results to separate files plus a summary JSON.
* **Key Classes/Methods**:
  - `FeatureImportanceAnalyzer`: Encapsulates model/data loading, feature preparation, and the various importance methods.
  - `run_all_analyses()`: A single call that runs all analyses and saves results to `output_dir/{gb_model,bilstm_model,combined_analysis}`.
* **Note**:
  - Import paths are currently written as `from models.hist_gb ...` and `from models.bilstm_cnn ...`, but the actual files live in `train_models/`. This inconsistency must be fixed before running, either by correcting the paths or by creating a package with that name.

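On the GB side, the permutation analysis is a direct use of scikit-learn. A minimal sketch of what the class method presumably wraps; the wrapper name and arguments are illustrative.

```python
import pandas as pd
from sklearn.inspection import permutation_importance

def gb_permutation_table(model, X, y, feature_names, n_repeats=10):
    """Rank features by the drop in score when each column is shuffled."""
    result = permutation_importance(model, X, y, n_repeats=n_repeats, random_state=42)
    return (pd.DataFrame({
                "feature": feature_names,
                "importance_mean": result.importances_mean,
                "importance_std": result.importances_std,
            })
            .sort_values("importance_mean", ascending=False))
```
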
---
## `utils/`

### `utils/config.py`

* **Purpose**:
  - Centralizes the management of paths for data, models, and results. Provides `create_directories()` to ensure that output directories exist (see the sketch after this section).
* **Note**:
  - Currently contains placeholder paths (e.g., `/path/to/...`). These must be adapted to the environment before execution, including `DATA_DIR`, `TRAIN_DATA`, `RESULT_DIR`, etc.

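A plausible shape for the config, assuming the attribute names referenced elsewhere in this document; the concrete file names and directory layout are guesses.

```python
import os

class BaseConfig:
    # Placeholder paths -- replace with real locations before running
    DATA_DIR = "/path/to/data"
    RESULT_DIR = "/path/to/results"

    TRAIN_DATA = os.path.join(DATA_DIR, "merged_train_data.csv")
    TEST_DATA = os.path.join(DATA_DIR, "merged_test_data.csv")
    VALIDATION_DATA = os.path.join(DATA_DIR, "validation_data.csv")
    SEAPHAGE_PROB = os.path.join(DATA_DIR, "seaphage_prob.csv")

    GB_DIR = os.path.join(RESULT_DIR, "gb")
    BILSTM_MODEL_DIR = os.path.join(RESULT_DIR, "bilstm")

    @classmethod
    def create_directories(cls):
        """Ensure all data/result directories exist."""
        for path in (cls.DATA_DIR, cls.RESULT_DIR, cls.GB_DIR, cls.BILSTM_MODEL_DIR):
            os.makedirs(path, exist_ok=True)
```
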
### `utils/function.py`

* **Purpose**: common utilities (the selection logic is sketched after this section):
  - **Self-training sample selection** (separate variants for CNN and GB): selects pseudo-labeled samples based on model probability, entropy, and an external `final_prob` confidence threshold.
  - **Saving training results and models**: `save_training_info()` saves the `.h5` model, a training-info JSON, and a weights pickle in one call.
  - **Data loading**: `load_data()` merges the datasets, validates required columns (`full_seq`, `label`, `source`), and downsamples `EUPLOTES` negative samples as needed.
  - **Evaluation**: `evaluate_model_gb()` and `evaluate_model_cnn()` calculate the common metrics and log loss.

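The selection criterion combines model confidence with the external score. A minimal sketch of the idea; the function name and thresholds are assumptions, and the real `select_low_confidence_samples_cnn()` also tracks entropy and dataset bookkeeping.

```python
import numpy as np

def select_pseudo_labeled(model_probs, final_prob, prob_threshold=0.9, conf_threshold=0.8):
    """Return pool indices and pseudo-labels where the model is confident
    (probability near 0 or 1, i.e. low entropy) and the external
    `final_prob` score passes its own threshold."""
    model_probs = np.asarray(model_probs, dtype=float)
    confident = (model_probs >= prob_threshold) | (model_probs <= 1 - prob_threshold)
    trusted = np.asarray(final_prob, dtype=float) >= conf_threshold
    idx = np.where(confident & trusted)[0]
    pseudo_labels = (model_probs[idx] >= 0.5).astype(int)
    return idx, pseudo_labels
```
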
---
## Interactions & Flow

1. Use `data_pre/mfe.py` to calculate the MFE columns and write them into the data CSVs.
2. Use `data_pre/seaphage_knn.py` to filter and supplement training samples based on confidence and KNN distance.
3. **Training**:
   - **GB (`train_models/hist_gb.py`)**: Constructs explicit features from `full_seq` + `mfe_*` for training and evaluation.
   - **BiLSTM-CNN (`train_models/bilstm_cnn.py`)**: Trains end-to-end on encoded sequences with the hybrid BiLSTM-CNN architecture, with support for iterative self-training.
4. **Analysis**: `model_feature/feature_analysis.py` outputs feature/positional importance from the various methods, plus a combined summary.

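A hypothetical driver stringing these steps together, assuming the entry points described above are importable from the repository root; the exact signatures (especially `train_hist_model()`) are guesses.

```python
# Hypothetical end-to-end driver; entry-point names follow the descriptions above.
from utils.config import BaseConfig
from data_pre.mfe import calculate_mfe_features
from train_models.hist_gb import train_hist_model
from train_models.bilstm_cnn import main as train_bilstm_cnn

BaseConfig.create_directories()

# 1. Write mfe_40bp / mfe_120bp back into the data CSVs
calculate_mfe_features(BaseConfig.TRAIN_DATA)
calculate_mfe_features(BaseConfig.TEST_DATA)

# 2. (Optional) augment the training set:  python data_pre/seaphage_knn.py

# 3. Train both models
gb_model = train_hist_model()
train_bilstm_cnn()

# 4. Importance analysis:  python model_feature/feature_analysis.py
```
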
---

@@ -138,40 +138,6 @@ class FeatureImportanceAnalyzer:
        return encoded, sequences

    def gb_built_in_importance(self, use_positive_only=True):
        """
        Feature importance analysis using HistGradientBoostingClassifier built-in importance

        Returns:
            importance_df: Feature importance DataFrame
        """
        try:
            data = self.fs_samples if use_positive_only else self.validation_data
            X, y, _ = self.prepare_gb_data(data)

            if self.gb_model is None or X is None or len(X) == 0:
                return None

            # Built-in feature importance (HistGradientBoostingClassifier does not
            # expose feature_importances_, so this branch is normally skipped)
            if hasattr(self.gb_model, 'feature_importances_'):
                feature_names = get_feature_names(self.gb_seq_length)
                importance_scores = self.gb_model.feature_importances_

                importance_df = pd.DataFrame({
                    'feature': feature_names,
                    'importance': importance_scores
                }).sort_values('importance', ascending=False)

                # Save results
                output_path = os.path.join(self.gb_dir, 'built_in_importance.csv')
                importance_df.to_csv(output_path, index=False)

                self.gb_results['built_in_importance'] = importance_df
                return importance_df

            return None
        except Exception:
            # Swallow errors and signal failure with None
            return None

    def gb_permutation_importance(self, n_repeats=10, use_positive_only=True):
        """

@@ -327,6 +293,63 @@ class FeatureImportanceAnalyzer:
        return integrated_grads

    def bilstm_saliency_map(self, n_samples=20, use_positive_only=True):
        """
        Saliency map analysis for BiLSTM-CNN model

        Args:
            n_samples: Number of samples for analysis
            use_positive_only: Whether to use only positive samples

        Returns:
            Average importance scores DataFrame
        """
        try:
            data = self.fs_samples if use_positive_only else self.validation_data
            data = data.sample(n=min(n_samples, len(data)), random_state=42) if len(data) > n_samples else data
            X, sequences = self.prepare_bilstm_data(data)

            if self.bilstm_model is None or X is None or len(X) == 0:
                return None

            saliency_maps = []
            for i in range(len(X)):
                saliency_map = self._calculate_saliency(self.bilstm_model, X[i:i+1])
                saliency_maps.append(saliency_map)

            avg_saliency = np.mean(saliency_maps, axis=0)
            positions = list(range(1, self.bilstm_seq_length + 1))

            importance_df = pd.DataFrame({
                'position': positions,
                'saliency': avg_saliency
            })

            output_path = os.path.join(self.bilstm_dir, 'saliency_map.csv')
            importance_df.to_csv(output_path, index=False)

            self.bilstm_results['saliency_map'] = importance_df
            return importance_df

        except Exception:
            # Swallow errors and signal failure with None
            return None

    def _calculate_saliency(self, model, input_tensor):
        """Helper function to calculate the saliency map for a single input."""
        input_tensor = tf.convert_to_tensor(input_tensor, dtype=tf.float32)
        with tf.GradientTape() as tape:
            tape.watch(input_tensor)
            predictions = model(input_tensor)
            target_class_output = predictions[:, 0]  # Assuming binary classification, target positive class

        gradients = tape.gradient(target_class_output, input_tensor)
        # Sum |grad| over the trailing channel axis; [0] selects the single batch element
        saliency = np.abs(gradients).sum(axis=-1)[0]

        # Normalize to [0, 1] for comparability across samples
        if np.max(saliency) > 0:
            saliency = saliency / np.max(saliency)

        return saliency

    def run_all_analyses(self, max_samples=20, use_positive_only=True):
        """
        Run all feature importance analysis methods

@@ -339,14 +362,16 @@ class FeatureImportanceAnalyzer:

        # GB model analyses
        if self.gb_model is not None:
            results['gb_permutation'] = self.gb_permutation_importance(use_positive_only=use_positive_only)
            results['gb_shap'] = self.gb_shap_analysis(max_samples, use_positive_only)

        # BiLSTM model analyses
        if self.bilstm_model is not None:
            results['bilstm_integrated_gradients'] = self.bilstm_integrated_gradients(
                n_samples=max_samples, use_positive_only=use_positive_only
            )
            results['bilstm_saliency_map'] = self.bilstm_saliency_map(
                n_samples=max_samples, use_positive_only=use_positive_only
            )

        # Save combined results summary