Example Scripts

CellDISECT comes with example scripts that demonstrate the complete workflow for training, inference, and perturbation prediction. These scripts are located in the examples/ directory of the repository.

Training Example

The training_example.py script demonstrates how to train a CellDISECT model with customizable architecture and training parameters. Here’s a breakdown of the key components:

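import numpy as np
import scvi
import scanpy as sc
import torch
from lightning.pytorch.loggers import WandbLogger
from celldisect import CellDISECT

# Load and prepare data
adata = sc.read_h5ad('PATH/TO/DATA.h5ad')
# Drop cells with zero total counts; np.asarray(...).ravel() turns the
# (n_cells, 1) row sum of a sparse matrix into a flat boolean mask
adata = adata[np.asarray(adata.X.sum(axis=1)).ravel() != 0].copy()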

# Define covariates
cats = [
    'cat1',
    'cat2',
    'cat3',
]
cell_type_included = False  # Set to True if cell type annotation is in cats
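Before training, it can help to confirm that every entry in cats is actually a column of adata.obs and to derive cell_type_included instead of setting it by hand. The sketch below is a hypothetical helper, not part of CellDISECT; it takes a plain list of column names (for example list(adata.obs.columns)) and assumes the cell type column is named 'cell_type'.

```python
def check_covariates(obs_columns, cats, cell_type_key="cell_type"):
    """Hypothetical helper: validate covariate names and derive the cell-type flag.

    obs_columns: column names of adata.obs, e.g. list(adata.obs.columns)
    cats: categorical covariates to disentangle
    cell_type_key: assumed name of the cell type annotation column
    """
    missing = [c for c in cats if c not in obs_columns]
    if missing:
        raise KeyError(f"Covariates not found in adata.obs: {missing}")
    duplicates = sorted({c for c in cats if cats.count(c) > 1})
    if duplicates:
        raise ValueError(f"Covariates listed more than once: {duplicates}")
    # True when the cell type annotation is one of the covariates
    return cell_type_key in cats


obs_columns = ["cat1", "cat2", "cat3", "n_counts"]
cats = ["cat1", "cat2", "cat3"]
cell_type_included = check_covariates(obs_columns, cats)
print(cell_type_included)  # False
```

Failing early on a misspelled covariate is cheaper than discovering it after model setup.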

Key Configuration Options:

  1. Architecture Parameters:

arch_dict = {
    'n_layers': 2,
    'n_hidden': 128,
    'n_latent_shared': 32,
    'n_latent_attribute': 32,
    'dropout_rate': 0.1,
    'weighted_classifier': False,
}

  2. Training Parameters:

train_dict = {
    'max_epochs': 1000,
    'batch_size': 256,
    'recon_weight': 20,
    'cf_weight': 0.8,
    'beta': 0.003,
    'clf_weight': 0.05,
    'adv_clf_weight': 0.014,
    'adv_period': 5,
    'n_cf': 1,
    'early_stopping_patience': 6,
    'early_stopping': True,
    'save_best': True,
}

  3. Training Plan Parameters:

plan_kwargs = {
    'lr': 0.003,
    'weight_decay': 0.00005,
    'ensemble_method_cf': True,
    'lr_patience': 5,
    'lr_factor': 0.5,
    'lr_scheduler_metric': 'loss_validation',
    'n_epochs_kl_warmup': 10,
}
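To make the schedule-related options more concrete, here is a plain-Python sketch of the three mechanisms they configure, written the way scvi-tools-style training loops commonly implement them: a linear KL warmup over n_epochs_kl_warmup epochs, a reduce-on-plateau learning-rate rule (lr_patience, lr_factor), and early stopping (early_stopping_patience). It is an illustration under those assumptions, not CellDISECT's own code; the real trainer may use different thresholds, counters, or monitored metrics.

```python
def kl_weight(epoch, n_epochs_kl_warmup=10, max_kl_weight=1.0):
    """Linear KL warmup: ramps from 0 to max_kl_weight over n_epochs_kl_warmup epochs."""
    if n_epochs_kl_warmup is None or n_epochs_kl_warmup <= 0:
        return max_kl_weight
    return min(max_kl_weight, max_kl_weight * epoch / n_epochs_kl_warmup)


def simulate(val_losses, lr=0.003, lr_patience=5, lr_factor=0.5,
             early_stopping_patience=6):
    """Replay validation losses through plateau-LR reduction and early stopping.

    Returns (epochs_run, final_lr, best_epoch).
    """
    best, best_epoch, epochs_since_best = float("inf"), 0, 0
    # The LR rule keeps its own counter because it resets after each reduction
    lr_best, epochs_since_lr_best = float("inf"), 0
    for epoch, loss in enumerate(val_losses):
        # Early stopping bookkeeping (save_best would keep the weights from best_epoch)
        if loss < best:
            best, best_epoch, epochs_since_best = loss, epoch, 0
        else:
            epochs_since_best += 1
        # Reduce-on-plateau: scale lr by lr_factor after more than lr_patience bad epochs
        if loss < lr_best:
            lr_best, epochs_since_lr_best = loss, 0
        else:
            epochs_since_lr_best += 1
            if epochs_since_lr_best > lr_patience:
                lr *= lr_factor
                epochs_since_lr_best = 0
        if epochs_since_best >= early_stopping_patience:
            return epoch + 1, lr, best_epoch
    return len(val_losses), lr, best_epoch


print([kl_weight(e) for e in (0, 5, 10, 20)])  # [0.0, 0.5, 1.0, 1.0]
# Loss plateaus after epoch 2: lr is halved once and training stops after 9 epochs
print(simulate([1.0, 0.9, 0.8] + [0.85] * 6))
```

With lr_patience=5 and early_stopping_patience=6, this sketch halves the learning rate on the same epoch it stops, so a larger early_stopping_patience is what would give the reduced learning rate time to take effect.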

The script includes optional Weights & Biases (wandb) integration for training monitoring.

Inference Example

The inference_example.py script shows how to load a trained model and perform various types of inference. Key features include:

  1. Loading a trained model:

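# Placeholders: the directory and name used when the model was saved
pre_path = 'PATH/TO/SAVE/YOUR/MODEL'
model_name = 'MODEL_NAME'
model = CellDISECT.load(f"{pre_path}/{model_name}", adata=adata)

  2. Extracting different latent representations: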

  • Z_0 (shared latent space)

  • Z_i (covariate-specific latent spaces)

  • Z_{-i} (complement latent spaces)

  • Z_0 + Z_i (shared latent space combined with a covariate-specific one)

  3. Computing neighbors and UMAP visualizations for all latent representations:

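import anndata as ad
import scanpy as sc

# Compute neighbors and UMAPs for each latent space
# (assumes the latents were stored earlier in adata.obsm under these keys)
for i in range(len(cats) + 1):
    if i == 0:
        latent_name = f"CellDISECT_Z_{i}"
    else:
        label = cats[i - 1]
        latent_name = f"CellDISECT_Z_{label}"

    # Wrap each latent in its own AnnData so neighbors/UMAP don't overwrite adata
    latent = ad.AnnData(X=adata.obsm[latent_name], obs=adata.obs)
    sc.pp.neighbors(adata=latent, use_rep="X")
    sc.tl.umap(adata=latent)
    # latent.obsm['X_umap'] now holds the 2-D embedding; plot or store it here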
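The four kinds of latent representation listed above can be pictured as blocks of one concatenated vector. The NumPy sketch below assumes, for illustration only, that Z_0 has n_latent_shared columns and each Z_i has n_latent_attribute columns (both 32 in the arch_dict above), that the complement Z_{-i} concatenates Z_0 with every Z_j for j ≠ i, and that Z_0 + Z_i is the concatenation of Z_0 and Z_i. Random arrays stand in for real model outputs, and the helper names are hypothetical; check the inference script for how CellDISECT actually assembles these spaces.

```python
import numpy as np

n_cells = 100
n_latent_shared = 32     # width of Z_0 (from arch_dict)
n_latent_attribute = 32  # width of each Z_i (from arch_dict)
cats = ["cat1", "cat2", "cat3"]

rng = np.random.default_rng(0)
# Stand-ins for model outputs: Z_0 and one block per covariate
z_0 = rng.normal(size=(n_cells, n_latent_shared))
z_attr = {c: rng.normal(size=(n_cells, n_latent_attribute)) for c in cats}


def complement(label):
    """Z_{-i}: Z_0 plus every covariate block except `label` (assumed definition)."""
    others = [z_attr[c] for c in cats if c != label]
    return np.concatenate([z_0, *others], axis=1)


def shared_plus(label):
    """Z_0 + Z_i: the shared block concatenated with one covariate block (assumed)."""
    return np.concatenate([z_0, z_attr[label]], axis=1)


print(complement("cat1").shape)   # (100, 96)
print(shared_plus("cat1").shape)  # (100, 64)
```

Any of these arrays could be wrapped in ad.AnnData(X=...) and passed through the same neighbors/UMAP loop.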

The script also includes commented-out plotting code that you can use in your analysis notebooks.

Using the Examples

  1. Copy the relevant example script to your working directory.

  2. Modify the paths and parameters according to your needs:

    • PATH/TO/DATA.h5ad: Path to your input data

    • PATH/TO/SAVE/YOUR/MODEL: Where to save the trained model

    • cats: List of categorical covariates from your data

    • Architecture and training parameters as needed

  3. For inference, make sure to:

    • Use the same covariate list as during training

    • Specify the correct path to your trained model

    • Adjust the output path for saving results
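One way to make sure inference uses the same covariate list as training is to write cats next to the saved model and compare on load. The helpers below are hypothetical (the file name covariates.json is an arbitrary choice, not something CellDISECT creates) and use only the standard library. Order is checked too, because the inference loop names each Z_i by its position in cats.

```python
import json
import os
import tempfile


def save_covariates(model_dir, cats):
    """Write the training covariates next to the model (hypothetical convention)."""
    os.makedirs(model_dir, exist_ok=True)
    with open(os.path.join(model_dir, "covariates.json"), "w") as f:
        json.dump(list(cats), f)


def check_covariates_match(model_dir, cats):
    """Raise if `cats` differs from the list saved at training time (order matters)."""
    with open(os.path.join(model_dir, "covariates.json")) as f:
        trained = json.load(f)
    if trained != list(cats):
        raise ValueError(f"Covariates {list(cats)} differ from training: {trained}")


with tempfile.TemporaryDirectory() as model_dir:
    save_covariates(model_dir, ["cat1", "cat2", "cat3"])
    check_covariates_match(model_dir, ["cat1", "cat2", "cat3"])  # passes silently
```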

Perturbation Example

The perturbation_example.py script demonstrates how to use CellDISECT for perturbation prediction. It covers:

  1. Preparing predefined embeddings:

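import numpy as np

# Load gene embeddings (e.g. GenePT, ESM); .item() unwraps the pickled
# dict of per-gene embeddings that was saved with np.save
gene_embeddings = np.load('PATH/TO/GENE_EMBEDDINGS.npy', allow_pickle=True).item()
adata.uns['pert_embeddings'] = gene_embeddings

  2. Setting up with perturbation support: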

CellDISECT.setup_anndata(
    adata,
    layer='counts',
    categorical_covariate_keys=['cell_type', 'perturbation'],
    perturbation_key='perturbation',
    perturbation_embedding_key='pert_embeddings',
    perturbation_combination_delimiter='+',
)

  3. Predicting perturbations (seen, unseen, or combinatorial):

x_ctrl, x_true, x_pred = model.predict_perturbation(
    adata,
    perturbation='GeneA+GeneB',
    source_perturbation='ctrl',
    cats=['cell_type', 'perturbation'],
    perturbation_key='perturbation',
)

  4. Evaluating predictions:

from celldisect import perturbation_metrics
metrics = perturbation_metrics(x_pred.numpy(), x_true.numpy(), x_ctrl.numpy())
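Step 1 loads the embeddings with np.load(..., allow_pickle=True).item(), which implies the .npy file holds a pickled Python dict. The sketch below shows one way such a file could be produced and checked against the perturbation labels before calling setup_anndata: it splits combinatorial labels such as 'GeneA+GeneB' on the '+' delimiter and reports genes without an embedding. The gene names, the 4-dimensional vectors, the helper missing_genes, and the choice to skip 'ctrl' are all assumptions for illustration; real embeddings (e.g. GenePT or ESM) are much wider.

```python
import os
import tempfile

import numpy as np

# Toy embeddings keyed by gene name (real ones come from GenePT, ESM, ...)
rng = np.random.default_rng(0)
gene_embeddings = {g: rng.normal(size=4).astype(np.float32) for g in ["GeneA", "GeneB", "GeneC"]}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "gene_embeddings.npy")
    np.save(path, gene_embeddings)  # a dict is stored as a 0-d object array (pickled)
    loaded = np.load(path, allow_pickle=True).item()  # .item() unwraps the dict


def missing_genes(perturbations, embeddings, delimiter="+", control="ctrl"):
    """Return genes named in perturbation labels that have no embedding."""
    missing = set()
    for label in perturbations:
        if label == control:
            continue  # assumption: the control label needs no embedding
        for gene in label.split(delimiter):
            if gene not in embeddings:
                missing.add(gene)
    return sorted(missing)


labels = ["ctrl", "GeneA", "GeneA+GeneB", "GeneC+GeneD"]
print(missing_genes(labels, loaded))  # ['GeneD']
```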

These scripts serve as comprehensive templates for working with CellDISECT and can be adapted to your specific use case.
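The metrics returned by perturbation_metrics are defined in the CellDISECT package. As a generic illustration of the kind of comparison such functions often make (not a description of perturbation_metrics itself), the sketch below computes the Pearson correlation between the predicted and observed mean expression changes relative to control, a common summary in perturbation benchmarks. The inputs are cells-by-genes NumPy arrays like x_pred.numpy(), x_true.numpy() and x_ctrl.numpy(); the function name delta_pearson and the toy data are hypothetical.

```python
import numpy as np


def delta_pearson(x_pred, x_true, x_ctrl):
    """Pearson r between predicted and true per-gene mean shifts from control."""
    ctrl_mean = x_ctrl.mean(axis=0)
    delta_pred = x_pred.mean(axis=0) - ctrl_mean
    delta_true = x_true.mean(axis=0) - ctrl_mean
    return float(np.corrcoef(delta_pred, delta_true)[0, 1])


rng = np.random.default_rng(0)
x_ctrl = rng.poisson(5.0, size=(200, 50)).astype(float)
shift = rng.normal(0.0, 2.0, size=50)  # toy "true" perturbation effect per gene
x_true = x_ctrl + shift
x_pred = x_ctrl + shift + rng.normal(0.0, 0.5, size=(200, 50))  # noisy prediction
print(round(delta_pearson(x_pred, x_true, x_ctrl), 3))
```

Correlating deltas rather than raw expression avoids rewarding a model that simply reproduces the control profile.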