celldisect.CellDISECT
- class celldisect.CellDISECT(*args: Any, **kwargs: Any)
Bases: RNASeqMixin, VAEMixin, UnsupervisedTrainingMixin, BaseModelClass, TunableMixin
CellDISECT model for single-cell RNA sequencing data analysis.
- Parameters:
adata – AnnData object that has been registered via setup_anndata().
n_hidden – Number of nodes per hidden layer.
n_latent_shared – Dimensionality of the shared latent space.
n_latent_attribute – Dimensionality of the latent space for each sensitive attribute.
n_layers – Number of hidden layers used for encoder and decoder neural networks.
dropout_rate – Dropout rate for neural networks.
gene_likelihood – One of:
- 'nb': Negative binomial distribution
- 'zinb': Zero-inflated negative binomial distribution
- 'poisson': Poisson distribution
latent_distribution – One of:
- 'normal': Normal distribution
- 'ln': Logistic normal distribution (Normal(0, I) transformed by softmax)
split_key – Key in adata.obs to split the data into training, validation, and test sets.
train_split – Values in split_key to be used for training.
valid_split – Values in split_key to be used for validation.
test_split – Values in split_key to be used for testing.
weighted_classifier – Whether to use weighted classifiers for categorical covariates.
**model_kwargs – Additional keyword arguments for the model.
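The three gene_likelihood options differ only in how an observed count is scored against the decoder's output. The sketch below writes out their per-count log-likelihoods, assuming the mean / inverse-dispersion parameterization (mu, theta, plus a zero-inflation probability pi) that is common in scVI-style models. The function names are hypothetical, and this is an illustration rather than CellDISECT's internal loss code.

```python
import math


def log_nb(x: int, mu: float, theta: float) -> float:
    """Log-probability of count x under a negative binomial with mean mu > 0
    and inverse dispersion theta > 0 (variance = mu + mu**2 / theta)."""
    log_theta_mu = math.log(theta + mu)
    return (
        math.lgamma(x + theta)
        - math.lgamma(theta)
        - math.lgamma(x + 1)
        + theta * (math.log(theta) - log_theta_mu)
        + x * (math.log(mu) - log_theta_mu)
    )


def log_zinb(x: int, mu: float, theta: float, pi: float) -> float:
    """Zero-inflated NB: with probability pi (0 <= pi < 1) the count is a structural zero."""
    if x == 0:
        # A zero can come from the zero-inflation component or from the NB itself.
        return math.log(pi + (1.0 - pi) * math.exp(log_nb(0, mu, theta)))
    return math.log(1.0 - pi) + log_nb(x, mu, theta)


def log_poisson(x: int, mu: float) -> float:
    """Log-probability of count x under a Poisson with rate mu > 0."""
    return x * math.log(mu) - mu - math.lgamma(x + 1)
```

As a sanity check, setting pi = 0 makes log_zinb equal log_nb, and a very large theta makes log_nb approach log_poisson.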
Methods
__init__(adata[, n_hidden, n_latent_shared, ...]): Initialize the CellDISECT model.
add_cluster_covariate(adata[, normalize_counts]): Run PCA on the gene expression matrix, then run Leiden clustering on the PCA components to create a cluster covariate that is added to adata.obs.
add_perturbation_embedding(name, embedding): Register an embedding vector for an unseen atomic perturbation.
get_perturbation_embeddings(): Return current perturbation names and their (combined) embeddings.
setup_anndata(adata[, layer, batch_key, ...]): Set up the AnnData object for the CellDISECT model.
train([max_epochs, use_gpu, train_size, ...]): Train the model.
- classmethod setup_anndata(adata: anndata.AnnData, layer: str | None = None, batch_key: str | None = None, labels_key: str | None = None, size_factor_key: str | None = None, categorical_covariate_keys: List[str] | None = None, continuous_covariate_keys: List[str] | None = None, add_cluster_covariate: bool = False, clustering_normalize_counts: bool = True, perturbation_key: str | None = None, perturbation_embedding_key: str | None = None, perturbation_combination_delimiter: str = '+', **kwargs)
Set up the AnnData object for the CellDISECT model.
This method configures the AnnData object by registering the necessary fields and optionally adding a cluster covariate. When perturbation_key is provided, the corresponding column in adata.obs is treated as a perturbation covariate whose embeddings come from adata.uns[perturbation_embedding_key] rather than being learned during training.
- Parameters:
adata (AnnData) – AnnData object to be set up.
layer (Optional[str], optional) – Layer in adata to use as the count data, by default None.
batch_key (Optional[str], optional) – Key in adata.obs for batch information, by default None.
labels_key (Optional[str], optional) – Key in adata.obs for labels, by default None.
size_factor_key (Optional[str], optional) – Key in adata.obs for size factors, by default None.
categorical_covariate_keys (Optional[List[str]], optional) – List of keys in adata.obs for categorical covariates, by default None.
continuous_covariate_keys (Optional[List[str]], optional) – List of keys in adata.obs for continuous covariates, by default None.
add_cluster_covariate (bool, optional) – Whether to add a cluster covariate to adata.obs, by default False.
clustering_normalize_counts (bool, optional) – Whether to normalize counts before clustering, by default True.
perturbation_key (Optional[str], optional) – Column in adata.obs that contains perturbation labels (e.g. "GeneA", "GeneA+GeneB"). When set, the perturbation covariate uses predefined embeddings instead of learned ones.
perturbation_embedding_key (Optional[str], optional) – Key in adata.uns whose value is a dict[str, np.ndarray] mapping atomic perturbation names to their vector representations (e.g. ESM or GenePT embeddings). Required when perturbation_key is set.
perturbation_combination_delimiter (str, optional) – Delimiter for combinatorial perturbation labels, by default "+".
**kwargs – Additional keyword arguments.
- Return type:
None
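When perturbation_key is set, each label such as "GeneA+GeneB" is split on perturbation_combination_delimiter and resolved against the atomic vectors stored in adata.uns[perturbation_embedding_key]. A minimal sketch of that lookup follows, using plain lists instead of np.ndarray. How CellDISECT actually merges the atomic vectors of a combination is not stated here, so the element-wise sum below is a hypothetical rule chosen for illustration, and both function names are hypothetical too.

```python
from typing import Dict, List


def split_perturbation(label: str, delimiter: str = "+") -> List[str]:
    """Split a combinatorial label such as "GeneA+GeneB" into its atomic names."""
    atoms = [atom.strip() for atom in label.split(delimiter)]
    return [atom for atom in atoms if atom]


def combined_embedding(
    label: str, embeddings: Dict[str, List[float]], delimiter: str = "+"
) -> List[float]:
    """Look up every atomic perturbation in `embeddings` and combine the vectors.

    Hypothetical combination rule: element-wise sum of the atomic vectors.
    """
    atoms = split_perturbation(label, delimiter)
    if not atoms:
        raise ValueError(f"empty perturbation label: {label!r}")
    missing = [atom for atom in atoms if atom not in embeddings]
    if missing:
        raise KeyError(f"no predefined embedding for: {missing}")
    dim = len(embeddings[atoms[0]])
    if any(len(embeddings[atom]) != dim for atom in atoms):
        raise ValueError("atomic embeddings must share one dimensionality")
    return [sum(embeddings[atom][i] for atom in atoms) for i in range(dim)]
```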
- classmethod add_cluster_covariate(adata: anndata.AnnData, normalize_counts: bool = True)
Run PCA on the gene expression matrix, then run Leiden clustering on the PCA components to create a cluster covariate that is added to adata.obs.
- Parameters:
adata (AnnData) – AnnData object containing the single-cell RNA sequencing data.
normalize_counts (bool, optional) – If True, takes the counts from adata.layers['counts'] and log-normalizes them before clustering, by default True.
- Return type:
None
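With normalize_counts=True, the counts in adata.layers['counts'] are log-normalized before PCA and Leiden clustering. The sketch below shows one common form of that step (per-cell scaling to a fixed total, then log(1 + x)) in plain Python. The fixed target_sum of 1e4 and the function name are assumptions; the exact normalization CellDISECT applies is not documented on this page.

```python
import math
from typing import List


def log_normalize(counts: List[List[float]], target_sum: float = 1e4) -> List[List[float]]:
    """Scale each cell (row) to `target_sum` total counts, then apply log(1 + x).

    Cells with zero total counts are left as all zeros.
    """
    normalized = []
    for cell in counts:
        total = sum(cell)
        scale = target_sum / total if total > 0 else 0.0
        normalized.append([math.log1p(value * scale) for value in cell])
    return normalized
```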
- predict_counterfactuals(adata: anndata.AnnData, cov_names: list[str], cov_values: list[str], cov_values_cf: list[str], cats: list[str], n_samples_from_source: int | None = None, seed: int | None = 0)
Predicts counterfactuals for a given subset of data.
This function estimates the counterfactual outcomes for a subset of data based on specified changes in covariate values.
- Parameters:
adata (AnnData) – The subset of the data for which the counterfactuals are to be predicted.
cov_names (list[str]) – Names of the covariates that are to be changed.
cov_values (list[str]) – Original values for the covariates that are to be changed.
cov_values_cf (list[str]) – Counterfactual values for the covariates that are to be changed.
cats (list[str]) – List of categorical covariate keys, in the same order used during training.
n_samples_from_source (Optional[int], optional) – Number of samples to take from the source data to predict the counterfactuals. If None, all samples from the source data are used. Defaults to None.
seed (Optional[int], optional) – Random seed for reproducibility. Defaults to 0. Only used if n_samples_from_source is not None.
- Returns:
A tuple of three tensors: the control (source) counts, the true counterfactual counts, and the predicted counterfactual counts. All three hold raw counts (not log-transformed).
- Return type:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
Examples
Single covariate change:
```python
cats = ['cell_type', 'condition']
cell_type_to_check = ['CD4 T',]
cov_names = ['condition']
cov_values = ['ctrl']
cov_values_cf = ['stimulated']
n_samples_from_source = 500

x_ctrl, x_true, x_pred = model.predict_counterfactuals(
    adata[(adata.obs['cell_type'].isin(cell_type_to_check))].copy(),
    cov_names=cov_names,
    cov_values=cov_values,
    cov_values_cf=cov_values_cf,
    cats=cats,
    n_samples_from_source=n_samples_from_source,
)
```
Multiple covariate change:
```python
cats = ['tissue', 'Sample ID', 'sex', 'Age_bin', 'CoarseCellType']
cell_type_to_check = 'Epithelial cell (luminal)'
cov_names = ['sex', 'tissue']
cov_values = ['female', 'breast']
cov_values_cf = ['male', 'prostate gland']
n_samples_from_source = None

x_ctrl, x_true, x_pred = model.predict_counterfactuals(
    adata[adata.obs['Broad cell type'] == cell_type_to_check].copy(),
    cov_names=cov_names,
    cov_values=cov_values,
    cov_values_cf=cov_values_cf,
    cats=cats,
    n_samples_from_source=n_samples_from_source,
)
```
- get_latent_representation(adata: anndata.AnnData | None = None, indices: Sequence[int] | None = None, batch_size: int | None = None, nullify_cat_covs_indices: List[int] | None = None, nullify_shared: bool | None = False) -> numpy.ndarray | Tuple[numpy.ndarray, numpy.ndarray]
Get the latent representation of the data.
This function computes the latent representation of the data using the trained model. It allows for optional nullification of specific categorical covariates or the shared latent space.
- Parameters:
adata (Optional[AnnData]) – Annotated data object. If None, uses the data registered with the model.
indices (Optional[Sequence[int]]) – Optional indices to subset the data.
batch_size (Optional[int]) – Batch size to use for data loading. If None, uses the default batch size.
nullify_cat_covs_indices (Optional[List[int]]) – List of indices of categorical covariates to nullify in the latent space. If None, no covariates are nullified.
nullify_shared (Optional[bool]) – If True, nullifies the shared latent space. Defaults to False.
- Returns:
The latent representation of the data. If nullify_cat_covs_indices or nullify_shared is specified, returns a tuple of the latent representation and the nullified latent representation.
- Return type:
Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]
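The nullification options can be pictured as zeroing blocks of a concatenated latent vector. The sketch below assumes a hypothetical layout [shared | attribute_0 | attribute_1 | ...] with n_latent_shared and n_latent_attribute entries per block. Neither that layout nor zeroing as the nullification rule is confirmed by this page, so treat the hypothetical nullify_latent helper only as a mental model.

```python
from typing import List, Optional


def nullify_latent(
    z: List[float],
    n_latent_shared: int,
    n_latent_attribute: int,
    nullify_cat_covs_indices: Optional[List[int]] = None,
    nullify_shared: bool = False,
) -> List[float]:
    """Return a copy of z with the selected latent blocks set to zero.

    Hypothetical layout: [shared | attribute_0 | attribute_1 | ...].
    """
    out = list(z)  # never modify the caller's vector
    spans = []
    if nullify_shared:
        spans.append((0, n_latent_shared))
    for i in nullify_cat_covs_indices or []:
        start = n_latent_shared + i * n_latent_attribute
        spans.append((start, start + n_latent_attribute))
    for start, stop in spans:
        if stop > len(out):
            raise IndexError(f"latent block [{start}, {stop}) exceeds length {len(out)}")
        for j in range(start, stop):
            out[j] = 0.0
    return out
```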
- get_cat_covariate_latents()
Returns the embeddings of the categorical covariates.
This function retrieves the embeddings of the categorical covariates from the trained model. It also returns the mappings of the categorical covariates.
- Parameters:
self (object) – The instance of the class.
- Returns:
A tuple containing two dictionaries:
covar_embeddings – Dictionary where keys are covariate names and values are the embeddings as numpy arrays.
covar_mappings – Dictionary where keys are covariate names and values are the mappings of the covariates.
- Return type:
Tuple[dict, dict]
- get_gene_importance() -> pandas.DataFrame
Computes gene importance for each encoder.
This method calculates a gene importance score for each gene based on the weights of the encoders. The importance is calculated by propagating the weights through the network layers to obtain an effective weight matrix from input genes to the latent space. This method is most accurate when the model is trained without biases and with a non-negative activation function like ReLU. If biases are present, they are ignored in this calculation.
- Returns:
A pandas DataFrame containing gene importance scores. The rows correspond to genes, and the columns correspond to the different encoders (e.g., ‘shared’, ‘attribute_0’, ‘attribute_1’, …).
- Return type:
pd.DataFrame
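The effective-weight idea can be sketched with plain matrix products: compose the encoder's weight matrices from input genes to the latent space, then score each gene by the magnitude of its effective weights. Biases and activations are ignored, matching the approximation described above. The (out_features, in_features) weight layout follows torch.nn.Linear; the sum-of-absolute-values aggregation and all function names are assumptions for illustration.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain-Python matrix product of a (n x k) and b (k x m)."""
    k = len(b)
    return [
        [sum(a[i][t] * b[t][j] for t in range(k)) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def effective_weights(layer_weights: List[Matrix]) -> Matrix:
    """Compose linear layers, ignoring biases and activations.

    layer_weights is in forward order; each matrix has shape
    (out_features, in_features), so the result is W_L @ ... @ W_1 with
    shape (latent_dim, n_genes).
    """
    eff = layer_weights[0]
    for w in layer_weights[1:]:
        eff = matmul(w, eff)
    return eff


def gene_importance(layer_weights: List[Matrix]) -> List[float]:
    """Hypothetical per-gene score: sum of |effective weight| over latent dimensions."""
    eff = effective_weights(layer_weights)
    return [sum(abs(row[g]) for row in eff) for g in range(len(eff[0]))]
```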
- train(max_epochs: int | None = None, use_gpu: str | int | bool | None = True, train_size: float = 0.8, validation_size: float | None = None, batch_size: int = 256, early_stopping: bool = True, save_best: bool = False, plan_kwargs: dict | None = None, recon_weight: Tunable[Union[float, int]] = 10, cf_weight: Tunable[Union[float, int]] = 1, beta: Tunable[Union[float, int]] = 1, clf_weight: Tunable[Union[float, int]] = 50, adv_clf_weight: Tunable[Union[float, int]] = 10, adv_period: Tunable[int] = 1, n_cf: Tunable[int] = 10, kappa_optimizer2: bool = True, n_epochs_pretrain_ae: int = 0, **trainer_kwargs)
Train the model.
- Parameters:
max_epochs (Optional[int]) – Number of passes through the dataset. If None, defaults to np.min([round((20000 / n_cells) * 400), 400]).
use_gpu (Optional[Union[str, int, bool]]) – Whether to use GPU for training. Can be a boolean, string, or integer specifying the GPU device.
train_size (float) – Size of the training set in the range [0.0, 1.0].
validation_size (Optional[float]) – Size of the validation set. If None, defaults to 1 - train_size. If train_size + validation_size < 1, the remaining cells belong to a test set.
batch_size (int) – Minibatch size to use during training.
early_stopping (bool) – Perform early stopping. Additional arguments can be passed in **trainer_kwargs. See Trainer for further options.
save_best (bool) – If True, save the best model state with respect to the validation loss; otherwise (the default), use the final state of the training procedure.
plan_kwargs (Optional[dict]) – Keyword arguments for TrainingPlan. Keyword arguments passed to train() will overwrite values present in plan_kwargs, when appropriate.
recon_weight (Tunable[Union[float, int]]) – Weight for the reconstruction loss of X.
cf_weight (Tunable[Union[float, int]]) – Weight for the reconstruction loss of X_cf.
beta (Tunable[Union[float, int]]) – Weight for the KL divergence of Zi.
clf_weight (Tunable[Union[float, int]]) – Weight for the Si classifier loss.
adv_clf_weight (Tunable[Union[float, int]]) – Weight for the adversarial classifier loss.
adv_period (Tunable[int]) – Adversarial training period.
n_cf (Tunable[int]) – Number of X_cf reconstructions (a random permutation of n VAEs and a random half-batch subset for each trial).
kappa_optimizer2 (bool) – Whether to use the second kappa optimizer.
n_epochs_pretrain_ae (int) – Number of epochs to pretrain the autoencoder.
**trainer_kwargs – Other keyword arguments for Trainer.
- Return type:
None
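The documented defaults for max_epochs and the data split can be reproduced with a few lines of arithmetic. In this sketch default_max_epochs mirrors the formula given for max_epochs, while split_fractions is a hypothetical helper that applies the train_size / validation_size rules as fractions only (how cells are rounded into each split is not covered).

```python
from typing import Optional, Tuple


def default_max_epochs(n_cells: int) -> int:
    """Documented default: np.min([round((20000 / n_cells) * 400), 400])."""
    return min(round((20000 / n_cells) * 400), 400)


def split_fractions(
    train_size: float = 0.8, validation_size: Optional[float] = None
) -> Tuple[float, float, float]:
    """Return (train, validation, test) fractions following the documented rules."""
    if not 0.0 <= train_size <= 1.0:
        raise ValueError("train_size must lie in [0.0, 1.0]")
    if validation_size is None:
        validation_size = 1.0 - train_size  # documented default
    test_size = 1.0 - train_size - validation_size  # leftover cells form the test set
    if test_size < -1e-9:
        raise ValueError("train_size + validation_size must not exceed 1")
    return train_size, validation_size, max(test_size, 0.0)
```

For example, 40,000 cells gives round(0.5 * 400) = 200 epochs, while any dataset of 20,000 cells or fewer is capped at 400.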
- generate(n_samples: int, covariates: dict[str, str], library_size: float | None = None, adata_for_library_size: anndata.AnnData | None = None, random_seed: int | None = None, include_shared_latent: bool = False)
Generate new cells conditionally based on specified categorical covariates.
This method generates n_samples new cells for a given combination of categorical covariates. The generation process involves sampling from the latent spaces and then decoding to get gene expression profiles.
The shared latent space (z_shared) is sampled from a standard normal distribution. The attribute-specific latent spaces (zs) are sampled from their respective prior encoders, conditioned on the provided covariates.
- Parameters:
n_samples – Number of cells to generate.
covariates – A dictionary mapping categorical covariate names to the desired values for generation. All categorical covariates used to train the model must be specified.
library_size – A specific library size to use for generation. Mutually exclusive with adata_for_library_size.
adata_for_library_size – An AnnData object to infer the median library size from. If cells with the specified covariates combination exist in this AnnData object, their median library size is used. Otherwise, the median library size of the entire adata_for_library_size is used. Mutually exclusive with library_size.
random_seed – A random seed for reproducibility.
include_shared_latent – Whether to include the expression from the decoder associated with the shared latent space (Z_0) in the average_expression. If False, the average is computed over attribute-specific decoders only. We suggest setting this to False for most use cases. The shared decoder is assumed to be decoder_0.
- Returns:
A tuple containing:
average_expression (numpy.ndarray): An array of shape (n_samples, n_vars) representing the average gene expression across the selected decoders.
all_expressions (dict[str, numpy.ndarray]): A dictionary where keys are decoder names (decoder_0, decoder_1, etc.) and values are numpy arrays of shape (n_samples, n_vars) representing the generated gene expression counts from each individual decoder.
- Return type:
Tuple[numpy.ndarray, dict[str, numpy.ndarray]]
Examples
```python
>>> n_generated_cells = 10
>>> desired_covariates = {"cell_type": "T-cell", "condition": "diseased"}
>>> avg_expr, all_expr = model.generate(
...     n_samples=n_generated_cells,
...     covariates=desired_covariates,
...     adata_for_library_size=adata
... )
```
To compute a custom average using only a subset of decoders (e.g., decoder_0 and decoder_2):
```python
>>> import numpy as np
>>> custom_avg = np.mean(
...     [all_expr["decoder_0"], all_expr["decoder_2"]], axis=0
... )
```
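The library-size rule for adata_for_library_size (the median over cells with the requested covariate combination, falling back to the median over all cells) can be sketched as below. Here cell_totals and cell_obs are hypothetical stand-ins for per-cell total counts and adata.obs rows, and requiring exactly one of library_size or cell_totals is a simplification of the documented mutual exclusivity.

```python
import statistics
from typing import Dict, List, Optional


def choose_library_size(
    covariates: Dict[str, str],
    library_size: Optional[float] = None,
    cell_totals: Optional[List[float]] = None,
    cell_obs: Optional[List[Dict[str, str]]] = None,
) -> float:
    """Pick the library size used to scale generated cells (hypothetical helper)."""
    if (library_size is None) == (cell_totals is None):
        raise ValueError("pass exactly one of library_size or cell_totals")
    if library_size is not None:
        return float(library_size)
    if cell_obs is None or len(cell_obs) != len(cell_totals):
        raise ValueError("cell_obs must give covariates for every cell")
    matching = [
        total
        for total, obs in zip(cell_totals, cell_obs)
        if all(obs.get(key) == value for key, value in covariates.items())
    ]
    # Fall back to the median over all cells when no cell has this combination.
    return float(statistics.median(matching or cell_totals))
```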
- add_perturbation_embedding(name: str, embedding: numpy.ndarray) -> None
Register an embedding vector for an unseen atomic perturbation.
After calling this method the model can predict counterfactuals for the new perturbation (or combinatorial perturbations that include it).
- Parameters:
name – Atomic perturbation name (e.g. "GeneC").
embedding – Vector representation with the same dimensionality as the predefined embeddings used during training.
- get_perturbation_embeddings() -> Tuple[List[str], numpy.ndarray]
Return current perturbation names and their (combined) embeddings.
- Returns:
names – Ordered list of perturbation category names.
embeddings – Numpy array of shape (n_perturbations, emb_dim).
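Conceptually, add_perturbation_embedding extends a name-to-vector table and get_perturbation_embeddings reads it back as (names, matrix). The class below is a hypothetical, pure-Python stand-in for that table, not CellDISECT's implementation. It enforces the documented requirement that a new vector match the training dimensionality; rejecting duplicate names and sorting names alphabetically are sketch choices.

```python
from typing import Dict, List, Sequence, Tuple


class PerturbationEmbeddingStore:
    """Hypothetical stand-in for the model's table of atomic perturbation embeddings."""

    def __init__(self, embeddings: Dict[str, Sequence[float]]):
        dims = {len(vector) for vector in embeddings.values()}
        if len(dims) != 1:
            raise ValueError("training embeddings must share one dimensionality")
        self.dim = dims.pop()
        self.embeddings: Dict[str, List[float]] = {
            name: list(vector) for name, vector in embeddings.items()
        }

    def add(self, name: str, embedding: Sequence[float]) -> None:
        """Register an unseen atomic perturbation after checking its dimensionality."""
        if name in self.embeddings:
            # Sketch choice: refuse to silently overwrite a known perturbation.
            raise ValueError(f"{name!r} is already registered")
        if len(embedding) != self.dim:
            raise ValueError(f"expected {self.dim} values, got {len(embedding)}")
        self.embeddings[name] = list(embedding)

    def table(self) -> Tuple[List[str], List[List[float]]]:
        """Ordered names and an (n_perturbations, emb_dim) list of rows."""
        names = sorted(self.embeddings)
        return names, [self.embeddings[name] for name in names]
```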
- predict_perturbation(adata: anndata.AnnData, perturbation: str, source_perturbation: str, cats: List[str], perturbation_key: str, new_embeddings: dict | None = None, n_samples_from_source: int | None = None, seed: int | None = 0, device: str | torch.device | None = None, batch_size: int | None = 256, source_adata: anndata.AnnData | None = None) -> Tuple[torch.Tensor, torch.Tensor | None, torch.Tensor]
Predict gene expression after a perturbation.
Memory-efficient implementation that avoids copying large adata objects.
- Parameters:
adata – AnnData object containing cells. Must have a 'counts' layer.
perturbation – Target perturbation label (e.g. "GeneA" or "GeneA+GeneB").
source_perturbation – Source / control perturbation label (e.g. "ctrl").
cats – List of categorical covariate keys (same order used during training).
perturbation_key – Which element of cats is the perturbation covariate.
new_embeddings – Optional dictionary mapping atomic perturbation names to embedding vectors.
n_samples_from_source – If set, randomly sample this many cells from the source group.
seed – Random seed for reproducibility.
device – Device to run prediction on. If None, uses the model's device.
batch_size – Batch size for forward passes. Default 256.
source_adata – Optional AnnData with control cells.
- Returns:
x_ctrl – Source (control) counts, shape (n_source, n_genes).
x_true – Target (ground-truth) counts if cells exist, else None.
x_pred – Predicted counts, shape (n_source, n_genes).
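The n_samples_from_source, seed, and batch_size arguments suggest a two-step pattern: optionally draw a reproducible subsample of source cells, then run forward passes chunk by chunk to bound memory. The sketch below shows that pattern in plain Python with a caller-supplied predict_fn. The function names are hypothetical, and returning every source cell when the requested sample size is too large is an assumption rather than documented behavior.

```python
import random
from typing import Callable, Iterator, List, Optional


def select_source_indices(
    n_source: int, n_samples_from_source: Optional[int] = None, seed: Optional[int] = 0
) -> List[int]:
    """Indices of source cells to perturb; a seeded subsample when requested."""
    indices = list(range(n_source))
    if n_samples_from_source is None or n_samples_from_source >= n_source:
        # Sketch choice: use every source cell when no (or too large a) sample size is given.
        return indices
    rng = random.Random(seed)
    return sorted(rng.sample(indices, n_samples_from_source))


def iter_batches(indices: List[int], batch_size: int) -> Iterator[List[int]]:
    """Yield consecutive chunks of at most batch_size indices."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    for start in range(0, len(indices), batch_size):
        yield indices[start : start + batch_size]


def predict_in_batches(
    indices: List[int], batch_size: int, predict_fn: Callable[[List[int]], List[float]]
) -> List[float]:
    """Run predict_fn chunk by chunk and concatenate the per-cell outputs."""
    outputs: List[float] = []
    for batch in iter_batches(indices, batch_size):
        outputs.extend(predict_fn(batch))
    return outputs
```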
- predict_perturbations(adata: anndata.AnnData, perturbations: List[str], source_perturbation: str, cats: List[str], perturbation_key: str, new_embeddings: dict | None = None, n_samples_from_source: int | None = None, seed: int | None = 0, device: str | torch.device | None = None, batch_size: int | None = 256, source_adata: anndata.AnnData | None = None) -> dict
Predict gene expression for multiple perturbations.
- Parameters:
adata – AnnData object containing cells. Must have a 'counts' layer.
perturbations – List of target perturbation labels.
source_perturbation – Source / control perturbation label.
cats – List of categorical covariate keys.
perturbation_key – Which element of cats is the perturbation covariate.
new_embeddings – Optional dictionary mapping atomic perturbation names to embedding vectors.
n_samples_from_source – If set, randomly sample this many cells from the source group.
seed – Random seed for reproducibility.
device – Device to run prediction on.
batch_size – Batch size for forward passes. Default 256.
source_adata – Optional AnnData with control cells.
- Returns:
Dictionary mapping each perturbation to (x_ctrl, x_true, x_pred).
- Return type: