Example Scripts
CellDISECT comes with example scripts that demonstrate the complete workflow for training, inference, and perturbation prediction. These scripts are located in the examples/ directory of the repository:
Training Example
The training_example.py script demonstrates how to train a CellDISECT model with customizable architecture and training parameters. Here’s a breakdown of the key components:
import scvi
import scanpy as sc
import torch
from lightning.pytorch.loggers import WandbLogger
from celldisect import CellDISECT
# Load and prepare data
adata = sc.read_h5ad('PATH/TO/DATA.h5ad')
adata = adata[adata.X.sum(1) != 0].copy()
# Define covariates
cats = [
'cat1',
'cat2',
'cat3',
]
cell_type_included = False # Set to True if cell type annotation is in cats
Key Configuration Options:
Architecture Parameters:
arch_dict = {
'n_layers': 2,
'n_hidden': 128,
'n_latent_shared': 32,
'n_latent_attribute': 32,
'dropout_rate': 0.1,
'weighted_classifier': False,
}
Training Parameters:
train_dict = {
'max_epochs': 1000,
'batch_size': 256,
'recon_weight': 20,
'cf_weight': 0.8,
'beta': 0.003,
'clf_weight': 0.05,
'adv_clf_weight': 0.014,
'adv_period': 5,
'n_cf': 1,
'early_stopping_patience': 6,
'early_stopping': True,
'save_best': True,
}
Training Plan Parameters:
plan_kwargs = {
'lr': 0.003,
'weight_decay': 0.00005,
'ensemble_method_cf': True,
'lr_patience': 5,
'lr_factor': 0.5,
'lr_scheduler_metric': 'loss_validation',
'n_epochs_kl_warmup': 10,
}
The script includes optional Weights & Biases (wandb) integration for training monitoring.
Inference Example
The inference_example.py script shows how to load a trained model and perform various types of inference. Key features include:
Loading a trained model:
model = CellDISECT.load(f"{pre_path}/{model_name}", adata=adata)
Extracting different latent representations:
Z_0 (shared latent space)
Z_i (covariate-specific latent spaces)
Z_{-i} (complement latent spaces)
Z_{0+Z_i} (combined latent spaces)
Computing neighbors and UMAP visualizations for all latent representations:
# Compute neighbors and UMAPs for each latent space
for i in range(len(cats) + 1):
if i == 0:
latent_name = f"CellDISECT_Z_{i}"
else:
label = cats[i - 1]
latent_name = f"CellDISECT_Z_{label}"
latent = ad.AnnData(X=adata.obsm[f"{latent_name}"], obs=adata.obs)
sc.pp.neighbors(adata=latent, use_rep="X")
sc.tl.umap(adata=latent)
The script also includes commented-out plotting code that you can use in your analysis notebooks.
Using the Examples
Copy the relevant example script to your working directory
Modify the paths and parameters according to your needs:
PATH/TO/DATA.h5ad: Path to your input dataPATH/TO/SAVE/YOUR/MODEL: Where to save the trained modelcats: List of categorical covariates from your dataArchitecture and training parameters as needed
For inference, make sure to:
Use the same covariate list as during training
Specify the correct path to your trained model
Adjust the output path for saving results
Perturbation Example
The perturbation_example.py script demonstrates how to use CellDISECT for perturbation prediction. It covers:
Preparing predefined embeddings:
# Load gene embeddings (e.g. GenePT, ESM)
gene_embeddings = np.load('PATH/TO/GENE_EMBEDDINGS.npy', allow_pickle=True).item()
adata.uns['pert_embeddings'] = gene_embeddings
Setting up with perturbation support:
CellDISECT.setup_anndata(
adata,
layer='counts',
categorical_covariate_keys=['cell_type', 'perturbation'],
perturbation_key='perturbation',
perturbation_embedding_key='pert_embeddings',
perturbation_combination_delimiter='+',
)
Predicting perturbations (seen, unseen, or combinatorial):
x_ctrl, x_true, x_pred = model.predict_perturbation(
adata,
perturbation='GeneA+GeneB',
source_perturbation='ctrl',
cats=['cell_type', 'perturbation'],
perturbation_key='perturbation',
)
Evaluating predictions:
from celldisect import perturbation_metrics
metrics = perturbation_metrics(x_pred.numpy(), x_true.numpy(), x_ctrl.numpy())
These scripts serve as comprehensive templates for working with CellDISECT and can be adapted to your specific use case.