Welcome to CellDISECT’s documentation!
Note
Beta Version Available: A beta version (0.2.0b1) with compatibility for Google Colab and newer versions of torch and scvi-tools is available on the beta-colab branch. Install it with pip install celldisect==0.2.0b1.
CellDISECT (Cell DISentangled Experts for Covariate counTerfactuals) is a causal generative model designed to disentangle known covariate variations from unknown ones at test time while simultaneously learning to make counterfactual predictions.
Installation
Prerequisites
We recommend using Anaconda or Miniconda to create a dedicated conda environment for CellDISECT.
Create and activate a conda environment:
conda create -n CellDISECT python=3.9
conda activate CellDISECT
Install PyTorch (tested with PyTorch 2.1.2 and CUDA 12):
conda install pytorch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 pytorch-cuda=12.1 -c pytorch -c nvidia
Install CellDISECT:
You can install the stable version using pip:
pip install celldisect
Or install the latest development version from GitHub:
pip install git+https://github.com/Lotfollahi-lab/CellDISECT
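After the steps above, a quick check can confirm that PyTorch was installed with working CUDA support. The following is a minimal sketch, not part of CellDISECT: the helper name cuda_status is hypothetical, and it uses only standard PyTorch calls.

```python
def cuda_status():
    """Return a one-line summary of the local PyTorch/CUDA setup.

    Hypothetical helper for illustration; it is not part of CellDISECT.
    """
    try:
        import torch
    except ImportError:
        return "torch is not installed"

    if torch.cuda.is_available():
        # torch.version.cuda is the CUDA version PyTorch was built against.
        device = torch.cuda.get_device_name(0)
        return f"torch {torch.__version__}, CUDA {torch.version.cuda}, device: {device}"
    return f"torch {torch.__version__}, no CUDA device visible"


print(cuda_status())
```

If the summary reports that no CUDA device is visible, training may still run on the CPU, but considerably more slowly.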
Optional Dependencies
For RAPIDS/rapids-singlecell support:
pip install \
    --extra-index-url=https://pypi.nvidia.com \
    cudf-cu12==24.4.* dask-cudf-cu12==24.4.* cuml-cu12==24.4.* \
    cugraph-cu12==24.4.* cuspatial-cu12==24.4.* cuproj-cu12==24.4.* \
    cuxfilter-cu12==24.4.* cucim-cu12==24.4.* pylibraft-cu12==24.4.* \
    raft-dask-cu12==24.4.* cuvs-cu12==24.4.*
pip install rapids-singlecell
For CUDA-enabled JAX:
pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
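Because the core and optional dependencies are installed in separate steps, it can be useful to list which of them ended up in the active environment. The sketch below uses only the Python standard library. The distribution names in PACKAGES are assumptions derived from the install commands above, and installed_versions is a hypothetical helper, not a CellDISECT function.

```python
from importlib import metadata

# Assumed distribution names, taken from the install commands above.
PACKAGES = ["torch", "celldisect", "scvi-tools", "rapids-singlecell", "jax"]


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


for name, version in installed_versions(PACKAGES).items():
    print(f"{name}: {version or 'not installed'}")
```

The optional packages are expected to show as not installed unless the corresponding commands were run.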
Quick Start
Here’s a simple example to get you started:
from celldisect import CellDISECT
import scanpy as sc
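# Load your data and use the raw counts as the main matrix
adata = sc.read_h5ad('your_data.h5ad')
adata.X = adata.layers['counts'].copy()
# Categorical covariates to model (column names in adata.obs)
cats = ['cov1', 'cov2']
# Set to True if cell type is already one of the covariates above
cell_type_included = False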
# Register the data, then initialize and train the model
CellDISECT.setup_anndata(
    adata,
    layer='counts',
    categorical_covariate_keys=cats,
    continuous_covariate_keys=[],
    add_cluster_covariate=not cell_type_included,  # add a cluster covariate when cell type is not included
)
model = CellDISECT(adata)
model.train()
# Make counterfactual predictions
predictions = model.predict_counterfactuals(
    adata,
    cov_names=['cov1'],
    cov_values=['val1'],
    cov_values_cf=['val2'],
    cats=cats,
)
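The example above assumes that adata has a 'counts' layer and that every name in cats is a column of adata.obs. A small pre-flight check can surface a missing layer or covariate before setup_anndata is called. This is a minimal sketch: check_inputs and missing_entries are hypothetical helpers, not part of the CellDISECT API, and they work on plain collections of names so they can be tried without loading any data.

```python
def missing_entries(required, available):
    """Return the items of `required` that are absent from `available`, in order."""
    available = set(available)
    return [item for item in required if item not in available]


def check_inputs(obs_columns, layer_names, covariates, layer="counts"):
    """Raise ValueError if the layer or any covariate column is missing.

    Hypothetical helper for illustration; it is not part of CellDISECT.
    """
    problems = []
    if layer not in set(layer_names):
        problems.append(f"layer {layer!r} not found")
    missing = missing_entries(covariates, obs_columns)
    if missing:
        problems.append(f"covariates missing from obs: {missing}")
    if problems:
        raise ValueError("; ".join(problems))


# Stand-in values; with real data, pass adata.obs.columns and adata.layers.keys().
check_inputs(
    obs_columns=["cov1", "cov2", "batch"],
    layer_names=["counts"],
    covariates=["cov1", "cov2"],
)
```

With an AnnData object, the equivalent call would be check_inputs(adata.obs.columns, adata.layers.keys(), cats), placed just before CellDISECT.setup_anndata.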