FScanpy-package/FScanpy/__init__.py

205 lines
7.9 KiB
Python

from .predictor import PRFPredictor
from . import data
import pandas as pd
import numpy as np
from typing import Union, List, Dict
# Package metadata. The author and e-mail fields are intentionally left as
# empty strings in this file.
__version__ = '0.3.0'
__author__ = ''
__email__ = ''
# Names exported by `from FScanpy import *`: the predictor class, the two
# convenience functions defined below, the `data` subpackage and the metadata.
__all__ = ['PRFPredictor', 'predict_prf', 'plot_prf_prediction', 'data', '__version__', '__author__', '__email__']
def predict_prf(
    sequence: Union[str, List[str], None] = None,
    data: Union[pd.DataFrame, None] = None,
    window_size: int = 3,
    short_threshold: float = 0.1,
    ensemble_weight: float = 0.4,
    model_dir: Union[str, None] = None
) -> pd.DataFrame:
    """
    Predict PRF sites, either by sliding window over sequences or per region.

    Exactly one of ``sequence`` and ``data`` must be given. All arguments are
    validated before a ``PRFPredictor`` is constructed, so an invalid call
    fails without creating a predictor.

    Args:
        sequence: Single DNA sequence (str) or several sequences (list/tuple
            of str) for sliding window prediction
        data: DataFrame for region prediction; must contain a 'Long_Sequence'
            or '399bp' column ('Long_Sequence' takes precedence if both exist)
        window_size: Sliding window size (default: 3)
        short_threshold: Short model (HistGB) probability threshold (default: 0.1)
        ensemble_weight: Weight of short model in ensemble (default: 0.4, long weight: 0.6)
        model_dir: Model directory path (optional)

    Returns:
        pandas.DataFrame: Prediction results containing the following main fields:
            - Short_Probability: Short model prediction probability
            - Long_Probability: Long model prediction probability
            - Ensemble_Probability: Ensemble prediction probability (main result)
            - Ensemble_Weights: Weight configuration information
        For a list/tuple of sequences a 'Sequence_ID' column ('seq_1',
        'seq_2', ...) is added and sequences whose prediction fails are
        skipped with a printed warning; an empty DataFrame is returned if
        none succeed. For region prediction the non-sequence columns of
        ``data`` are copied onto the result. If region prediction fails, a
        warning is printed and a zero-probability frame carrying all
        original columns of ``data`` is returned instead.

    Raises:
        ValueError: If neither or both of ``sequence`` and ``data`` are given,
            if ``ensemble_weight`` is outside [0.0, 1.0], if ``sequence`` is
            not a str/list/tuple, if ``data`` is not a DataFrame, or if
            ``data`` lacks a 'Long_Sequence'/'399bp' column.

    Examples:
        # 1. Single sequence sliding window prediction
        >>> from FScanpy import predict_prf
        >>> sequence = "ATGCGTACGT..."
        >>> results = predict_prf(sequence=sequence)
        # 2. Multiple sequences sliding window prediction
        >>> sequences = ["ATGCGTACGT...", "GCTATAGCAT..."]
        >>> results = predict_prf(sequence=sequences)
        # 3. Custom ensemble weight ratio
        >>> results = predict_prf(sequence=sequence, ensemble_weight=0.3)  # 3:7 ratio
        # 4. DataFrame region prediction
        >>> import pandas as pd
        >>> data = pd.DataFrame({
        ...     'Long_Sequence': ['ATGCGT...', 'GCTATAG...']  # or use '399bp'
        ... })
        >>> results = predict_prf(data=data)
    """
    # --- Validate input parameters (before any predictor is created) ---
    if sequence is None and data is None:
        raise ValueError("Must provide either sequence or data parameter")
    if sequence is not None and data is not None:
        raise ValueError("Cannot provide both sequence and data parameters")
    if not (0.0 <= ensemble_weight <= 1.0):
        raise ValueError("ensemble_weight must be between 0.0 and 1.0")

    seq_column = None
    if sequence is not None:
        # Previously an unsupported type fell through every branch and the
        # function silently returned None; reject it explicitly instead.
        if not isinstance(sequence, (str, list, tuple)):
            raise ValueError(
                "sequence parameter must be a string or a list/tuple of strings")
    else:
        if not isinstance(data, pd.DataFrame):
            raise ValueError("data parameter must be pandas DataFrame type")
        # Check column names (support both new and old naming conventions)
        if 'Long_Sequence' in data.columns:
            seq_column = 'Long_Sequence'
        elif '399bp' in data.columns:
            seq_column = '399bp'
        else:
            raise ValueError("DataFrame must contain 'Long_Sequence' or '399bp' column")

    predictor = PRFPredictor(model_dir=model_dir)

    # --- Sliding window prediction mode ---
    if sequence is not None:
        if isinstance(sequence, str):
            # Single sequence prediction
            return predictor.predict_sequence(
                sequence, window_size, short_threshold, ensemble_weight)

        # Multiple sequences prediction: best effort, a failing sequence is
        # reported and skipped so the remaining ones are still returned.
        results = []
        for i, seq in enumerate(sequence, 1):
            try:
                result = predictor.predict_sequence(
                    seq, window_size, short_threshold, ensemble_weight)
                result['Sequence_ID'] = f'seq_{i}'
                results.append(result)
            except Exception as e:
                print(f"Warning: Sequence {i} prediction failed - {str(e)}")
        return pd.concat(results, ignore_index=True) if results else pd.DataFrame()

    # --- Region prediction mode ---
    sequence_columns = ('Long_Sequence', '399bp', 'Short_Sequence', '33bp')
    try:
        results = predictor.predict_regions(
            data[seq_column], short_threshold, ensemble_weight)
        # Add other (non-sequence) columns from original data
        for col in data.columns:
            if col not in sequence_columns:
                results[col] = data[col].values
        return results
    except Exception as e:
        print(f"Warning: Region prediction failed - {str(e)}")
        # Create placeholder results with one zero-probability row per input row
        n_rows = len(data)
        long_weight = 1.0 - ensemble_weight
        results = pd.DataFrame({
            'Short_Probability': [0.0] * n_rows,
            'Long_Probability': [0.0] * n_rows,
            'Ensemble_Probability': [0.0] * n_rows,
            'Ensemble_Weights': [f'Short:{ensemble_weight:.1f}, Long:{long_weight:.1f}'] * n_rows
        })
        # Add original data columns (all of them, sequence columns included)
        for col in data.columns:
            results[col] = data[col].values
        return results
def plot_prf_prediction(
    sequence: str,
    window_size: int = 3,
    short_threshold: float = 0.65,
    long_threshold: float = 0.8,
    ensemble_weight: float = 0.4,
    title: str = None,
    save_path: str = None,
    figsize: tuple = (12, 8),
    dpi: int = 300,
    model_dir: str = None
) -> tuple:
    """
    Run sliding window PRF prediction on one sequence and plot the result.

    This is a convenience wrapper: it checks ``ensemble_weight``, creates a
    ``PRFPredictor`` for ``model_dir`` and forwards every other argument to
    ``PRFPredictor.plot_sequence_prediction``.

    Args:
        sequence: Input DNA sequence
        window_size: Sliding window size (default: 3)
        short_threshold: Short model (HistGB) filtering threshold (default: 0.65)
        long_threshold: Long model (BiLSTM-CNN) filtering threshold (default: 0.8)
        ensemble_weight: Weight of short model in ensemble (default: 0.4, long weight: 0.6)
        title: Plot title (optional)
        save_path: Save path (optional, saves plot if provided)
        figsize: Figure size (default: (12, 8))
        dpi: Figure resolution (default: 300)
        model_dir: Model directory path (optional)

    Returns:
        tuple: (pd.DataFrame, matplotlib.figure.Figure) prediction results and
        figure object, as returned by ``PRFPredictor.plot_sequence_prediction``

    Raises:
        ValueError: If ``ensemble_weight`` is not within [0.0, 1.0].

    Examples:
        # 1. Simple plotting
        >>> import matplotlib.pyplot as plt
        >>> from FScanpy import plot_prf_prediction
        >>> sequence = "ATGCGTACGT..."
        >>> results, fig = plot_prf_prediction(sequence)
        >>> plt.show()
        # 2. Custom thresholds and ensemble weights
        >>> results, fig = plot_prf_prediction(
        ...     sequence,
        ...     short_threshold=0.7,
        ...     long_threshold=0.85,
        ...     ensemble_weight=0.3,  # 3:7 weight ratio
        ...     title="Custom Weight Prediction Results",
        ...     save_path="prediction_result.png"
        ... )
        # 3. Equal weight combination
        >>> results, fig = plot_prf_prediction(
        ...     sequence,
        ...     ensemble_weight=0.5  # 5:5 equal weights
        ... )
        # 4. Long model dominated
        >>> results, fig = plot_prf_prediction(
        ...     sequence,
        ...     ensemble_weight=0.2  # 2:8 weights, long model dominated
        ... )
    """
    # Keep the chained comparison: it is False for NaN, so NaN is rejected too.
    weight_in_range = 0.0 <= ensemble_weight <= 1.0
    if not weight_in_range:
        raise ValueError("ensemble_weight must be between 0.0 and 1.0")

    # Everything except model_dir is passed straight through to the predictor.
    plot_options = {
        'sequence': sequence,
        'window_size': window_size,
        'short_threshold': short_threshold,
        'long_threshold': long_threshold,
        'ensemble_weight': ensemble_weight,
        'title': title,
        'save_path': save_path,
        'figsize': figsize,
        'dpi': dpi,
    }
    return PRFPredictor(model_dir=model_dir).plot_sequence_prediction(**plot_options)