celldisect package

Submodules

celldisect.data module

class celldisect.data.AnnDataSplitter(*args: Any, **kwargs: Any)[source]

Bases: DataSplitter

setup(stage: str | None = None)[source]
train_dataloader()[source]
val_dataloader()[source]
test_dataloader()[source]

celldisect.trainingplan module

class celldisect.trainingplan.CellDISECTTrainingPlan(*args: Any, **kwargs: Any)[source]

Bases: TrainingPlan

Train VAEs with adversarial loss option to encourage latent space mixing.

Parameters:
  • module (BaseModuleClass) – A module instance from class BaseModuleClass.

  • recon_weight (Tunable[Union[float, int]]) – Weight for the reconstruction loss of X.

  • cf_weight (Tunable[Union[float, int]]) – Weight for the reconstruction loss of X_cf.

  • beta (Tunable[Union[float, int]]) – Weight for the KL divergence of Zi.

  • clf_weight (Tunable[Union[float, int]]) – Weight for the Si classifier loss.

  • adv_clf_weight (Tunable[Union[float, int]]) – Weight for the adversarial classifier loss.

  • adv_period (Tunable[int]) – Adversarial training period.

  • n_cf (Tunable[int]) – Number of X_cf reconstructions (a random permutation of n VAEs and a random half-batch subset for each trial).

  • optimizer (Tunable[Literal["Adam", "AdamW", "Custom"]], optional) – One of “Adam” (Adam), “AdamW” (AdamW), or “Custom”, which requires a custom optimizer creator callable to be passed via optimizer_creator. Default is “Adam”.

  • optimizer_creator (Optional[TorchOptimizerCreator], optional) – A callable taking in parameters and returning an Optimizer. This allows using any PyTorch optimizer with custom hyperparameters. Default is None.

  • lr (Tunable[float], optional) – Learning rate used for optimization, when optimizer_creator is None. Default is 1e-3.

  • weight_decay (Tunable[float], optional) – Weight decay used in optimization, when optimizer_creator is None. Default is 1e-6.

  • n_steps_kl_warmup (Tunable[int], optional) – Number of training steps (minibatches) to scale weight on KL divergences from 0 to 1. Only activated when n_epochs_kl_warmup is set to None. Default is None.

  • n_epochs_kl_warmup (Tunable[int], optional) – Number of epochs to scale weight on KL divergences from 0 to 1. Overrides n_steps_kl_warmup when both are not None. Default is 400.

  • n_epochs_pretrain_ae (Tunable[int], optional) – Number of epochs to pretrain the autoencoder. Default is 0.

  • reduce_lr_on_plateau (Tunable[bool], optional) – Whether to monitor the validation loss and reduce the learning rate when the validation-set lr_scheduler_metric plateaus. Default is True.

  • lr_factor (Tunable[float], optional) – Factor to reduce learning rate. Default is 0.6.

  • lr_patience (Tunable[int], optional) – Number of epochs with no improvement after which learning rate will be reduced. Default is 30.

  • lr_threshold (Tunable[float], optional) – Threshold for measuring the new optimum. Default is 0.0.

  • lr_scheduler_metric (Literal["loss_validation"], optional) – Which metric to track for learning rate reduction. Default is “loss_validation”.

  • lr_min (float, optional) – Minimum learning rate allowed. Default is 0.

  • scale_adversarial_loss (Union[float, Literal["auto"]], optional) – Scaling factor on the adversarial components of the loss. With “auto”, the adversarial loss is scaled from 1 to 0, the inverse of the KL warmup (which goes from 0 to 1). Default is “auto”.

  • ensemble_method_cf (bool, optional) – Whether to use the new counterfactual method. Default is True.

  • kappa_optimizer2 (bool, optional) – Whether to use the second kappa optimizer. Default is True.

  • **loss_kwargs – Keyword args to pass to the loss method of the module. kl_weight should not be passed here and is handled automatically.
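The interplay between n_epochs_kl_warmup, n_steps_kl_warmup and scale_adversarial_loss="auto" can be sketched as below. This is a minimal sketch that assumes a linear ramp (the parameter descriptions only state that the KL weight goes from 0 to 1 and that the "auto" adversarial scale goes from 1 to 0); kl_weight and adversarial_scale are hypothetical helper names, not part of celldisect.

```python
def kl_weight(epoch, step, n_epochs_kl_warmup=400, n_steps_kl_warmup=None,
              min_weight=0.0, max_weight=1.0):
    """Hypothetical helper: linear 0 -> 1 KL warmup (assumed shape of the schedule)."""
    # n_epochs_kl_warmup takes precedence over n_steps_kl_warmup when it is set.
    if n_epochs_kl_warmup is not None:
        progress = epoch / n_epochs_kl_warmup
    elif n_steps_kl_warmup is not None:
        progress = step / n_steps_kl_warmup
    else:
        # No warmup configured: use the full weight from the start.
        return max_weight
    return min(max_weight, min_weight + (max_weight - min_weight) * progress)


def adversarial_scale(kl_w, scale_adversarial_loss="auto"):
    """Hypothetical helper: weight applied to the adversarial terms of the loss."""
    if scale_adversarial_loss == "auto":
        # "auto": run opposite to the KL warmup, i.e. from 1 down to 0.
        return 1.0 - kl_w
    return float(scale_adversarial_loss)
```

With the default n_epochs_kl_warmup=400, the KL weight would reach 0.5 at epoch 200, at which point the "auto" adversarial scale would also be 0.5.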

initialize_train_metrics()[source]

Initialize train related metrics.

initialize_val_metrics()[source]

Initialize val related metrics.

compute_and_log_metrics(loss_output: dict, metrics: Dict[str, scvi.train._metrics.ElboMetric], mode: str)

Computes and logs metrics.

This function updates the provided metrics dictionary with the values from the loss output and logs them using the appropriate logging method.

Parameters:
  • loss_output (dict) – Dictionary containing the loss output from the scvi-tools module.

  • metrics (Dict[str, ElboMetric]) – Dictionary of metrics to update.

  • mode (str) – Postfix string to add to the metric name for extra metrics.

adv_classifier_metrics(inference_outputs, detach_z=True)[source]

Computes the loss for the adversarial classifier.

This function calculates the classification metrics for the adversarial classifier using the provided inference outputs.

Parameters:
  • inference_outputs (dict) – Dictionary containing the outputs from the inference step.

  • detach_z (bool, optional) – Whether to detach the latent representation z, by default True.

Returns:

A tuple containing the mean cross-entropy (CE) loss, accuracy, and F1 score.

Return type:

tuple
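For a single sensitive attribute, the three returned values can be computed from classifier logits as in the NumPy sketch below. classifier_metrics is a hypothetical helper; it assumes integer class labels and a macro-averaged F1, and the actual method may average differently or aggregate over several attributes.

```python
import numpy as np


def classifier_metrics(logits, labels):
    """Hypothetical helper: mean CE, accuracy and macro-F1 for one attribute."""
    logits = np.asarray(logits, dtype=float)
    labels = np.asarray(labels)
    # Numerically stable log-softmax over classes.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    ce = -log_probs[np.arange(len(labels)), labels].mean()
    preds = logits.argmax(axis=1)
    acc = (preds == labels).mean()
    # Macro F1: unweighted mean of per-class F1 over classes seen in labels or preds.
    f1s = []
    for c in np.union1d(labels, preds):
        tp = np.sum((preds == c) & (labels == c))
        fp = np.sum((preds == c) & (labels != c))
        fn = np.sum((preds != c) & (labels == c))
        denom = 2 * tp + fp + fn
        f1s.append(2 * tp / denom if denom else 0.0)
    return float(ce), float(acc), float(np.mean(f1s))
```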

training_step(batch, batch_idx)[source]

Training step for adversarial training.

validation_step(batch, batch_idx)[source]

Validation step.

on_train_epoch_end()[source]

Update the learning rate via scheduler steps.

on_validation_epoch_end() → None[source]

configure_optimizers()[source]

Configure optimizers for adversarial training.

celldisect.tuner_base module

celldisect.utils module

class celldisect.utils.TRAIN_MODE(value)[source]

Bases: int, Enum

An enumeration.

RECONST = 0
RECONST_CF = 1
KL_Z = 2
CLASSIFICATION = 3
ADVERSARIAL = 4
__format__(format_spec)

Returns format using actual value type unless __str__ has been overridden.

class celldisect.utils.LOSS_KEYS(value)[source]

Bases: str, Enum

An enumeration.

LOSS = 'loss'
RECONST_LOSS_X = 'rec_x'
RECONST_LOSS_X_CF = 'rec_x_cf'
KL_Z = 'kl_z'
CLASSIFICATION_LOSS = 'ce'
ACCURACY = 'acc'
F1 = 'f1'
__format__(format_spec)

Returns format using actual value type unless __str__ has been overridden.

celldisect.utils.one_hot_cat(n_cat_list: List[int], cat_covs: torch.Tensor)[source]

celldisect.utils.parse_perturbation(name: str, delimiter: str = '+') → List[str][source]

Split a (possibly combinatorial) perturbation name into atomic components.

Parameters:
  • name – Perturbation label, e.g. "GeneA+GeneB" or "ctrl".

  • delimiter – Separator used for combinatorial perturbations.

Returns:

List of atomic perturbation names.

Return type:

List[str]

celldisect.utils.validate_perturbation_embeddings(adata, perturbation_key: str, embedding_key: str, delimiter: str = '+') → None[source]

Check that every atomic perturbation in adata.obs[perturbation_key] has an entry in adata.uns[embedding_key].

Parameters:
  • adata – Annotated data object.

  • perturbation_key – Column in adata.obs containing perturbation labels.

  • embedding_key – Key in adata.uns whose value is a dict[str, array] mapping atomic perturbation names to their vector representations.

  • delimiter – Separator used for combinatorial perturbations.

Raises:
  • KeyError – If embedding_key is not found in adata.uns.

  • ValueError – If any atomic perturbation is missing from the embeddings dictionary.
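The check described above amounts to splitting every label into atomic names and looking each one up. In this sketch, check_embeddings is a hypothetical helper; labels stands in for adata.obs[perturbation_key] and uns for adata.uns, so it runs without AnnData. Under this reading, a control label such as "ctrl" is itself an atomic perturbation and needs its own entry.

```python
from typing import Dict, Iterable, Mapping

import numpy as np


def check_embeddings(labels: Iterable[str],
                     uns: Mapping[str, Dict[str, np.ndarray]],
                     embedding_key: str,
                     delimiter: str = "+") -> None:
    """Hypothetical stand-in for validate_perturbation_embeddings."""
    if embedding_key not in uns:
        raise KeyError(f"{embedding_key!r} not found in uns")
    embeddings = uns[embedding_key]
    # Every atomic name appearing in any (possibly combinatorial) label.
    atoms = {atom for label in set(labels) for atom in label.split(delimiter)}
    missing = sorted(atoms - set(embeddings))
    if missing:
        raise ValueError(f"Missing embeddings for: {missing}")
```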

celldisect.utils.build_perturbation_embedding_matrix(category_names: List[str], predefined_embeddings: Dict[str, numpy.ndarray], delimiter: str = '+') → torch.Tensor[source]

Build an embedding matrix for all perturbation categories.

For combinatorial perturbations, the component embeddings are summed.

Parameters:
  • category_names – Ordered list of perturbation category names (from AnnData mapping).

  • predefined_embeddings – Dictionary mapping atomic perturbation names to vectors.

  • delimiter – Separator for combinatorial perturbations.

Returns:

Tensor of shape (n_categories, emb_dim).

Return type:

torch.Tensor
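The summing rule can be illustrated with NumPy. build_embedding_matrix is a hypothetical stand-in that returns a NumPy array; the documented function returns a torch.Tensor, which could be obtained by passing the result to torch.from_numpy.

```python
from typing import Dict, List

import numpy as np


def build_embedding_matrix(category_names: List[str],
                           predefined_embeddings: Dict[str, np.ndarray],
                           delimiter: str = "+") -> np.ndarray:
    """Hypothetical sketch: one row per category, combinatorial rows are sums."""
    rows = []
    for name in category_names:
        atoms = name.split(delimiter)
        # A single perturbation is simply a one-term sum.
        vectors = [np.asarray(predefined_embeddings[a], dtype=float) for a in atoms]
        rows.append(np.sum(vectors, axis=0))
    return np.stack(rows)  # shape (n_categories, emb_dim)
```

Because rows follow the order of category_names, row i lines up with the integer code i of the AnnData category mapping.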

celldisect.utils.perturbation_metrics(pred: numpy.ndarray, true: numpy.ndarray, ctrl: numpy.ndarray, top_n_de: int = 20) → dict[source]

Compute standard perturbation prediction evaluation metrics.

Parameters:
  • pred – Predicted mean gene expression, shape (n_genes,) or (n_cells, n_genes). If 2-D, the mean across cells is taken.

  • true – Ground-truth mean gene expression (same shape convention).

  • ctrl – Control (source) mean gene expression (same shape convention).

  • top_n_de – Number of top differentially expressed genes to evaluate.

Returns:

  • Dictionary with the following keys:

    • pearson_mean – Pearson r between predicted and true mean expression

    • pearson_delta – Pearson r between predicted and true delta (vs ctrl)

    • mse – Mean squared error of mean expression

    • top_de_pearson – Pearson r on top-N DE genes (ranked by |true - ctrl|)

    • top_de_cosine – Cosine similarity on top-N DE genes
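A NumPy sketch of these metrics follows. The keys and the DE ranking by |true - ctrl| come from the description above; everything else is an assumption, in particular that top_de_pearson and top_de_cosine compare deltas (prediction minus control versus truth minus control) on the selected genes. perturbation_metrics_sketch and its helpers are hypothetical names.

```python
import numpy as np


def _to_mean(a):
    # Collapse (n_cells, n_genes) inputs to per-gene means.
    a = np.asarray(a, dtype=float)
    return a.mean(axis=0) if a.ndim == 2 else a


def _pearson(a, b):
    return float(np.corrcoef(a, b)[0, 1])


def _cosine(a, b):
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))


def perturbation_metrics_sketch(pred, true, ctrl, top_n_de=20):
    pred, true, ctrl = _to_mean(pred), _to_mean(true), _to_mean(ctrl)
    d_pred, d_true = pred - ctrl, true - ctrl
    # Top-N DE genes ranked by |true - ctrl|.
    top = np.argsort(-np.abs(d_true))[:top_n_de]
    return {
        "pearson_mean": _pearson(pred, true),
        "pearson_delta": _pearson(d_pred, d_true),
        "mse": float(np.mean((pred - true) ** 2)),
        # Assumption: both top-DE metrics compare deltas on the selected genes.
        "top_de_pearson": _pearson(d_pred[top], d_true[top]),
        "top_de_cosine": _cosine(d_pred[top], d_true[top]),
    }
```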

class celldisect.utils.PerturbationNetwork(*args: Any, **kwargs: Any)[source]

Bases: Module

forward(perts, dosages)[source]

Parameters:
  • perts – Shape (batch_size, max_comb_len).

  • dosages – Shape (batch_size, max_comb_len).

Module contents

class celldisect.CellDISECT(*args: Any, **kwargs: Any)[source]

Bases: RNASeqMixin, VAEMixin, UnsupervisedTrainingMixin, BaseModelClass, TunableMixin

CellDISECT model for single-cell RNA sequencing data analysis.

Parameters:
  • adata – AnnData object that has been registered via setup_anndata().

  • n_hidden – Number of nodes per hidden layer.

  • n_latent_shared – Dimensionality of the shared latent space.

  • n_latent_attribute – Dimensionality of the latent space for each sensitive attribute.

  • n_layers – Number of hidden layers used for encoder and decoder neural networks.

  • dropout_rate – Dropout rate for neural networks.

  • gene_likelihood

    One of:

    • 'nb' - Negative binomial distribution

    • 'zinb' - Zero-inflated negative binomial distribution

    • 'poisson' - Poisson distribution

  • latent_distribution

    One of:

    • 'normal' - Normal distribution

    • 'ln' - Logistic normal distribution (Normal(0, I) transformed by softmax)

  • split_key – Key in adata.obs to split the data into training, validation, and test sets.

  • train_split – Values in split_key to be used for training.

  • valid_split – Values in split_key to be used for validation.

  • test_split – Values in split_key to be used for testing.

  • weighted_classifier – Whether to use weighted classifiers for categorical covariates.

  • **model_kwargs – Additional keyword arguments for the model.
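To make the 'ln' option concrete, the sketch below draws from Normal(0, I) and pushes each draw through a softmax, so every latent vector lies on the probability simplex. It only illustrates the distribution named above; in a VAE encoder the normal's mean and variance would typically come from the network rather than being fixed at 0 and I. sample_logistic_normal is a hypothetical helper.

```python
import numpy as np


def sample_logistic_normal(n_samples, n_latent, seed=0):
    """Hypothetical helper: Normal(0, I) draws mapped through a softmax."""
    rng = np.random.default_rng(seed)
    z = rng.standard_normal((n_samples, n_latent))
    # Stable softmax per row: subtract the row max before exponentiating.
    z = z - z.max(axis=1, keepdims=True)
    expz = np.exp(z)
    return expz / expz.sum(axis=1, keepdims=True)
```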

classmethod setup_anndata(adata: anndata.AnnData, layer: str | None = None, batch_key: str | None = None, labels_key: str | None = None, size_factor_key: str | None = None, categorical_covariate_keys: List[str] | None = None, continuous_covariate_keys: List[str] | None = None, add_cluster_covariate: bool = False, clustering_normalize_counts: bool = True, perturbation_key: str | None = None, perturbation_embedding_key: str | None = None, perturbation_combination_delimiter: str = '+', **kwargs)

Set up the AnnData object for the CellDISECT model.

This method configures the AnnData object by registering the necessary fields and optionally adding a cluster covariate. When perturbation_key is provided, the corresponding column in adata.obs is treated as a perturbation covariate whose embeddings come from adata.uns[perturbation_embedding_key] rather than being learned during training.

Parameters:
  • adata (AnnData) – AnnData object to be set up.

  • layer (Optional[str], optional) – Layer in adata to use as the count data, by default None.

  • batch_key (Optional[str], optional) – Key in adata.obs for batch information, by default None.

  • labels_key (Optional[str], optional) – Key in adata.obs for labels, by default None.

  • size_factor_key (Optional[str], optional) – Key in adata.obs for size factors, by default None.

  • categorical_covariate_keys (Optional[List[str]], optional) – List of keys in adata.obs for categorical covariates, by default None.

  • continuous_covariate_keys (Optional[List[str]], optional) – List of keys in adata.obs for continuous covariates, by default None.

  • add_cluster_covariate (bool, optional) – Whether to add a cluster covariate to adata.obs, by default False.

  • clustering_normalize_counts (bool, optional) – Whether to normalize counts before clustering, by default True.

  • perturbation_key (Optional[str], optional) – Column in adata.obs that contains perturbation labels (e.g. "GeneA", "GeneA+GeneB"). When set, the perturbation covariate uses predefined embeddings instead of learned ones.

  • perturbation_embedding_key (Optional[str], optional) – Key in adata.uns whose value is a dict[str, np.ndarray] mapping atomic perturbation names to their vector representations (e.g. ESM or GenePT embeddings). Required when perturbation_key is set.

  • perturbation_combination_delimiter (str, optional) – Delimiter for combinatorial perturbation labels, by default "+".

  • **kwargs – Additional keyword arguments.

Return type:

None

classmethod add_cluster_covariate(adata: anndata.AnnData, normalize_counts: bool = True)[source]

Run PCA on the gene expression matrix and run Leiden clustering on the PCA components to create a cluster covariate to be added to the adata.obs.

Parameters:
  • adata (AnnData) – AnnData object containing the single-cell RNA sequencing data.

  • normalize_counts (bool, optional) – If True, takes the counts from the adata.layers[‘counts’] and log normalizes them, by default True.

Return type:

None

predict_given_covs_depricated(adata: anndata.AnnData, cats: List[str], cov_idx: int, cov_value_cf, batch_size: int | None = None) → Tuple[torch.Tensor, torch.Tensor]

predict_counterfactuals(adata: anndata.AnnData, cov_names: list[str], cov_values: list[str], cov_values_cf: list[str], cats: list[str], n_samples_from_source: int | None = None, seed: int | None = 0)

Predicts counterfactuals for a given subset of data.

This function estimates the counterfactual outcomes for a subset of data based on specified changes in covariate values.

Parameters:
  • adata (AnnData) – The subset of the data for which the counterfactuals are to be predicted.

  • cov_names (list[str]) – Names of the covariates that are to be changed.

  • cov_values (list[str]) – Original values for the covariates that are to be changed.

  • cov_values_cf (list[str]) – Counterfactual values for the covariates that are to be changed.

  • cats (list[str]) – Names of the categorical covariates.

  • n_samples_from_source (Optional[int], optional) – Number of samples to take from the source data to predict the counterfactuals. If None, all samples from the source data are used. Defaults to None.

  • seed (Optional[int], optional) – Random seed for reproducibility. Defaults to 0. Only used if n_samples_from_source is not None.

Returns:

Control (source), True Counterfactuals, and Predicted Counterfactual COUNTS (not log-transformed).

Return type:

Tuple[torch.Tensor, torch.Tensor, torch.Tensor]

Examples

Single covariate change:

cats = ['cell_type', 'condition']
cell_type_to_check = ['CD4 T',]
cov_names = ['condition']
cov_values = ['ctrl']
cov_values_cf = ['stimulated']
n_samples_from_source = 500
x_ctrl, x_true, x_pred = model.predict_counterfactuals(
    adata[(adata.obs['cell_type'].isin(cell_type_to_check))].copy(),
    cov_names=cov_names,
    cov_values=cov_values,
    cov_values_cf=cov_values_cf,
    cats=cats,
    n_samples_from_source=n_samples_from_source,
)

Multiple covariate change:

cats = ['tissue', 'Sample ID', 'sex', 'Age_bin', 'CoarseCellType']
cell_type_to_check = 'Epithelial cell (luminal)'
cov_names = ['sex', 'tissue']
cov_values = ['female', 'breast']
cov_values_cf = ['male', 'prostate gland']
n_samples_from_source = None
x_ctrl, x_true, x_pred = model.predict_counterfactuals(
    adata[adata.obs['Broad cell type'] == cell_type_to_check].copy(),
    cov_names=cov_names,
    cov_values=cov_values,
    cov_values_cf=cov_values_cf,
    cats=cats,
    n_samples_from_source=n_samples_from_source,
)

get_latent_representation(adata: anndata.AnnData | None = None, indices: Sequence[int] | None = None, batch_size: int | None = None, nullify_cat_covs_indices: List[int] | None = None, nullify_shared: bool | None = False) → numpy.ndarray | Tuple[numpy.ndarray, numpy.ndarray]

Get the latent representation of the data.

This function computes the latent representation of the data using the trained model. It allows for optional nullification of specific categorical covariates or the shared latent space.

Parameters:
  • adata (Optional[AnnData]) – Annotated data object. If None, uses the data registered with the model.

  • indices (Optional[Sequence[int]]) – Optional indices to subset the data.

  • batch_size (Optional[int]) – Batch size to use for data loading. If None, uses the default batch size.

  • nullify_cat_covs_indices (Optional[List[int]]) – List of indices of categorical covariates to nullify in the latent space. If None, no covariates are nullified.

  • nullify_shared (Optional[bool]) – If True, nullifies the shared latent space. Defaults to False.

Returns:

The latent representation of the data. If nullify_cat_covs_indices or nullify_shared is specified, returns a tuple of the latent representation and the nullified latent representation.

Return type:

Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]

get_cat_covariate_latents()

Returns the embeddings of the categorical covariates.

This function retrieves the embeddings of the categorical covariates from the trained model. It also returns the mappings of the categorical covariates.

Parameters:

self (object) – The instance of the class.

Returns:

A tuple containing two dictionaries:

  • covar_embeddings – Dictionary where keys are covariate names and values are the embeddings as numpy arrays.

  • covar_mappings – Dictionary where keys are covariate names and values are the mappings of the covariates.

Return type:

Tuple[dict, dict]

get_gene_importance() → pandas.DataFrame

Computes gene importance for each encoder.

This method calculates a gene importance score for each gene based on the weights of the encoders. The importance is calculated by propagating the weights through the network layers to obtain an effective weight matrix from input genes to the latent space. This method is most accurate when the model is trained without biases and with a non-negative activation function like ReLU. If biases are present, they are ignored in this calculation.

Returns:

A pandas DataFrame containing gene importance scores. The rows correspond to genes, and the columns correspond to the different encoders (e.g., ‘shared’, ‘attribute_0’, ‘attribute_1’, …).

Return type:

pd.DataFrame
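The weight propagation described above can be sketched by chaining each encoder's linear layers, ignoring biases and activations exactly as the description says. The aggregation over latent dimensions (mean absolute effective weight) is an assumption, and effective_weights and gene_importance_sketch are hypothetical helpers; matrices follow the (out_features, in_features) layout used by torch.nn.Linear.

```python
import numpy as np


def effective_weights(weights):
    """Chain linear layers (input layer first) into a (n_latent, n_genes) matrix."""
    effective = np.asarray(weights[0], dtype=float)
    for w in weights[1:]:
        effective = np.asarray(w, dtype=float) @ effective
    return effective


def gene_importance_sketch(encoders):
    """encoders: dict mapping an encoder name to its list of weight matrices.

    Assumption: importance = mean |effective weight| over latent dimensions.
    """
    return {name: np.abs(effective_weights(ws)).mean(axis=0)
            for name, ws in encoders.items()}
```

Wrapping the result as pd.DataFrame(scores, index=gene_names) would give the genes-by-encoders layout described above.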

train(max_epochs: int | None = None, use_gpu: str | int | bool | None = True, train_size: float = 0.8, validation_size: float | None = None, batch_size: int = 256, early_stopping: bool = True, save_best: bool = False, plan_kwargs: dict | None = None, recon_weight: Tunable[Union[float, int]] = 10, cf_weight: Tunable[Union[float, int]] = 1, beta: Tunable[Union[float, int]] = 1, clf_weight: Tunable[Union[float, int]] = 50, adv_clf_weight: Tunable[Union[float, int]] = 10, adv_period: Tunable[int] = 1, n_cf: Tunable[int] = 10, kappa_optimizer2: bool = True, n_epochs_pretrain_ae: int = 0, **trainer_kwargs)[source]

Train the model.

Parameters:
  • max_epochs (Optional[int]) – Number of passes through the dataset. If None, defaults to np.min([round((20000 / n_cells) * 400), 400]).

  • use_gpu (Optional[Union[str, int, bool]]) – Whether to use GPU for training. Can be a boolean, string, or integer specifying the GPU device.

  • train_size (float) – Size of the training set in the range [0.0, 1.0].

  • validation_size (Optional[float]) – Size of the validation set. If None, defaults to 1 - train_size. If train_size + validation_size < 1, the remaining cells belong to a test set.

  • batch_size (int) – Minibatch size to use during training.

  • early_stopping (bool) – Perform early stopping. Additional arguments can be passed via **trainer_kwargs. See Trainer for further options.

  • save_best (bool) – If True, save the best model state with respect to the validation loss; if False (default), use the final state of the training procedure.

  • plan_kwargs (Optional[dict]) – Keyword arguments for TrainingPlan. Keyword arguments passed to train() will overwrite values present in plan_kwargs, when appropriate.

  • recon_weight (Tunable[Union[float, int]]) – Weight for the reconstruction loss of X.

  • cf_weight (Tunable[Union[float, int]]) – Weight for the reconstruction loss of X_cf.

  • beta (Tunable[Union[float, int]]) – Weight for the KL divergence of Zi.

  • clf_weight (Tunable[Union[float, int]]) – Weight for the Si classifier loss.

  • adv_clf_weight (Tunable[Union[float, int]]) – Weight for the adversarial classifier loss.

  • adv_period (Tunable[int]) – Adversarial training period.

  • n_cf (Tunable[int]) – Number of X_cf reconstructions (a random permutation of n VAEs and a random half-batch subset for each trial).

  • kappa_optimizer2 (bool) – Whether to use the second kappa optimizer.

  • n_epochs_pretrain_ae (int) – Number of epochs to pretrain the autoencoder.

  • **trainer_kwargs – Other keyword arguments for Trainer.

Return type:

None
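The default epoch heuristic and the split rule stated for max_epochs and validation_size can be written out as follows. The formula is the one quoted in the max_epochs description; default_max_epochs and split_fractions are hypothetical helpers, and how fractions are rounded to whole cells is not covered here.

```python
import numpy as np


def default_max_epochs(n_cells: int) -> int:
    # Fewer epochs for larger datasets, capped at 400.
    return int(np.min([round((20000 / n_cells) * 400), 400]))


def split_fractions(train_size=0.8, validation_size=None):
    # validation_size=None falls back to 1 - train_size; any remainder is the test set.
    if validation_size is None:
        validation_size = 1.0 - train_size
    test_size = 1.0 - train_size - validation_size
    return train_size, validation_size, test_size
```

For example, a 40,000-cell dataset would default to 200 epochs, while anything up to 20,000 cells stays at the 400-epoch cap.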

generate(n_samples: int, covariates: dict[str, str], library_size: float | None = None, adata_for_library_size: anndata.AnnData | None = None, random_seed: int | None = None, include_shared_latent: bool = False)

Generate new cells conditionally based on specified categorical covariates.

This method generates n_samples new cells for a given combination of categorical covariates. The generation process involves sampling from the latent spaces and then decoding to get gene expression profiles.

The shared latent space (z_shared) is sampled from a standard normal distribution. The attribute-specific latent spaces (zs) are sampled from their respective prior encoders, conditioned on the provided covariates.

Parameters:
  • n_samples – Number of cells to generate.

  • covariates – A dictionary mapping categorical covariate names to the desired values for generation. All categorical covariates used to train the model must be specified.

  • library_size – A specific library size to use for generation. Mutually exclusive with adata_for_library_size.

  • adata_for_library_size – An AnnData object to infer the median library size from. If cells with the specified covariates combination exist in this AnnData object, their median library size is used. Otherwise, the median library size of the entire adata_for_library_size is used. Mutually exclusive with library_size.

  • random_seed – A random seed for reproducibility.

  • include_shared_latent – Whether to include the expression from the decoder associated with the shared latent space (Z_0) in the average_expression. If False, the average is computed over attribute-specific decoders only. We suggest setting this to False for most use cases. The shared decoder is assumed to be decoder_0.

Returns:

A tuple containing:

  • average_expression (numpy.ndarray): An array of shape (n_samples, n_vars) representing the average gene expression across the selected decoders.

  • all_expressions (dict[str, numpy.ndarray]): A dictionary where keys are decoder names (decoder_0, decoder_1, etc.) and values are numpy arrays of shape (n_samples, n_vars) representing the generated gene expression counts from each individual decoder.

Return type:

Tuple[numpy.ndarray, dict[str, numpy.ndarray]]

Examples

>>> n_generated_cells = 10
>>> desired_covariates = {"cell_type": "T-cell", "condition": "diseased"}
>>> avg_expr, all_expr = model.generate(
...     n_samples=n_generated_cells,
...     covariates=desired_covariates,
...     adata_for_library_size=adata
... )

To compute a custom average using only a subset of decoders (e.g., decoder_0 and decoder_2):

>>> import numpy as np
>>> custom_avg = np.mean(
...     [all_expr["decoder_0"], all_expr["decoder_2"]], axis=0
... )
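The library-size rules for generate() (explicit value, else median over cells matching the covariate combination, else median over all cells) can be sketched as below. pick_library_size is a hypothetical helper; counts and obs stand in for the count matrix and adata.obs of adata_for_library_size, and treating a cell's library size as its total count is an assumption.

```python
import numpy as np


def pick_library_size(covariates, library_size=None, counts=None, obs=None):
    """Hypothetical helper following the rules documented for generate()."""
    if library_size is not None and counts is not None:
        raise ValueError("library_size and counts are mutually exclusive")
    if library_size is not None:
        return float(library_size)
    if counts is None or obs is None:
        # This sketch needs one source; the model's behavior here is not documented.
        raise ValueError("pass either library_size or counts together with obs")
    # Assumption: a cell's library size is its total count.
    sizes = np.asarray(counts, dtype=float).sum(axis=1)
    mask = np.ones(len(sizes), dtype=bool)
    for key, value in covariates.items():
        mask &= np.asarray(obs[key]) == value
    # Median over matching cells if any exist, otherwise over all cells.
    return float(np.median(sizes[mask] if mask.any() else sizes))
```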

add_perturbation_embedding(name: str, embedding: numpy.ndarray) → None[source]

Register an embedding vector for an unseen atomic perturbation.

After calling this method the model can predict counterfactuals for the new perturbation (or combinatorial perturbations that include it).

Parameters:
  • name – Atomic perturbation name (e.g. "GeneC").

  • embedding – Vector representation with the same dimensionality as the predefined embeddings used during training.

get_perturbation_embeddings() → Tuple[List[str], numpy.ndarray][source]

Return current perturbation names and their (combined) embeddings.

Returns:

  • names – Ordered list of perturbation category names.

  • embeddings – Numpy array of shape (n_perturbations, emb_dim).

predict_perturbation(adata: anndata.AnnData, perturbation: str, source_perturbation: str, cats: List[str], perturbation_key: str, new_embeddings: dict | None = None, n_samples_from_source: int | None = None, seed: int | None = 0, device: str | torch.device | None = None, batch_size: int | None = 256, source_adata: anndata.AnnData | None = None) → Tuple[torch.Tensor, torch.Tensor | None, torch.Tensor]

Predict gene expression after a perturbation.

Memory-efficient implementation that avoids copying large adata objects.

Parameters:
  • adata – AnnData object containing cells. Must have a 'counts' layer.

  • perturbation – Target perturbation label (e.g. "GeneA" or "GeneA+GeneB").

  • source_perturbation – Source / control perturbation label (e.g. "ctrl").

  • cats – List of categorical covariate keys (same order used during training).

  • perturbation_key – Name of the perturbation covariate; must be one of the keys listed in cats.

  • new_embeddings – Optional dictionary mapping atomic perturbation names to embedding vectors.

  • n_samples_from_source – If set, randomly sample this many cells from the source group.

  • seed – Random seed for reproducibility.

  • device – Device to run prediction on. If None, uses the model’s device.

  • batch_size – Batch size for forward passes. Default 256.

  • source_adata – Optional AnnData with control cells.

Returns:

  • x_ctrl – Source (control) counts, shape (n_source, n_genes).

  • x_true – Target (ground-truth) counts if cells with the target perturbation exist in adata, else None.

  • x_pred – Predicted counts, shape (n_source, n_genes).
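The optional subsampling (n_samples_from_source with seed) and the batched forward passes (batch_size) can be pictured as an index generator over the source cells. This is a sketch under assumptions: iter_source_batches is a hypothetical helper, and the method may use a different random generator or sampling scheme.

```python
import numpy as np


def iter_source_batches(n_source, n_samples_from_source=None, seed=0, batch_size=256):
    """Hypothetical helper: yield index arrays over (optionally subsampled) source cells."""
    idx = np.arange(n_source)
    if n_samples_from_source is not None:
        # Reproducible subsample without replacement (assumed).
        rng = np.random.default_rng(seed)
        idx = rng.choice(idx, size=n_samples_from_source, replace=False)
    # Fixed-size chunks keep peak memory bounded during prediction.
    for start in range(0, len(idx), batch_size):
        yield idx[start:start + batch_size]
```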

predict_perturbations(adata: anndata.AnnData, perturbations: List[str], source_perturbation: str, cats: List[str], perturbation_key: str, new_embeddings: dict | None = None, n_samples_from_source: int | None = None, seed: int | None = 0, device: str | torch.device | None = None, batch_size: int | None = 256, source_adata: anndata.AnnData | None = None) → dict

Predict gene expression for multiple perturbations.

Parameters:
  • adata – AnnData object containing cells. Must have a 'counts' layer.

  • perturbations – List of target perturbation labels.

  • source_perturbation – Source / control perturbation label.

  • cats – List of categorical covariate keys.

  • perturbation_key – Name of the perturbation covariate; must be one of the keys listed in cats.

  • new_embeddings – Optional dictionary mapping atomic perturbation names to embedding vectors.

  • n_samples_from_source – If set, randomly sample this many cells from the source group.

  • seed – Random seed for reproducibility.

  • device – Device to run prediction on.

  • batch_size – Batch size for forward passes. Default 256.

  • source_adata – Optional AnnData with control cells.

Returns:

Dictionary mapping each perturbation to (x_ctrl, x_true, x_pred).

Return type:

dict

class celldisect.CellDISECTModule(*args: Any, **kwargs: Any)[source]

Bases: BaseModuleClass

Variational auto-encoder module.

Parameters:
  • n_input (int) – Number of input genes.

  • n_hidden (Tunable[int], optional) – Number of nodes per hidden layer, by default 128.

  • n_latent_shared (Tunable[int], optional) – Dimensionality of the shared latent space (Z_{-s}), by default 10.

  • n_latent_attribute (Tunable[int], optional) – Dimensionality of the latent space for each sensitive attribute (Z_{s_i}), by default 10.

  • n_layers (Tunable[int], optional) – Number of hidden layers used for encoder and decoder NNs, by default 1.

  • n_cats_per_cov (Optional[Iterable[int]], optional) – Number of categories for each extra categorical covariate, by default None.

  • dropout_rate (Tunable[float], optional) – Dropout rate for neural networks, by default 0.1.

  • log_variational (bool, optional) – Log(data+1) prior to encoding for numerical stability. Not normalization, by default True.

  • gene_likelihood (Tunable[Literal["zinb", "nb", "poisson"]], optional) – One of ‘nb’ (Negative binomial distribution), ‘zinb’ (Zero-inflated negative binomial distribution), or ‘poisson’ (Poisson distribution), by default “zinb”.

  • latent_distribution (Tunable[Literal["normal", "ln"]], optional) – One of ‘normal’ (Isotropic normal) or ‘ln’ (Logistic normal with normal params N(0, 1)), by default “normal”.

  • deeply_inject_covariates (Tunable[bool], optional) – Whether to concatenate covariates into output of hidden layers in encoder/decoder. This option only applies when n_layers > 1. The covariates are concatenated to the input of subsequent hidden layers, by default True.

  • use_batch_norm (Tunable[Literal["encoder", "decoder", "none", "both"]], optional) – Whether to use batch norm in layers, by default “both”.

  • use_layer_norm (Tunable[Literal["encoder", "decoder", "none", "both"]], optional) – Whether to use layer norm in layers, by default “none”.

  • var_activation (Optional[Callable], optional) – Callable used to ensure positivity of the variational distributions’ variance. When None, defaults to torch.exp, by default None.

  • use_custom_embs (bool, optional) – Whether to use custom embeddings, by default False.

  • embeddings (Union[torch.Tensor, List[torch.Tensor]], optional) – Custom embeddings to use if use_custom_embs is True, by default None.

  • classifier_weights (Optional[list], optional) – Weights for the classifiers, by default None.

  • bias (bool, optional) – Whether to use bias in the encoder and decoder layers, by default True.

  • perturbation_cov_idx (Optional[int], optional) – Index (in the categorical covariates list) of the perturbation covariate. When set, that covariate uses predefined embeddings instead of learned ones.

  • predefined_pert_embeddings (Optional[dict], optional) – Mapping from atomic perturbation name to its vector representation (e.g. ESM or GenePT embeddings). Required when perturbation_cov_idx is not None.

  • perturbation_category_names (Optional[list], optional) – Ordered category names for the perturbation covariate (mirrors the integer-index mapping from the AnnData manager).

  • perturbation_combination_delimiter (str, optional) – Delimiter for combinatorial perturbation labels, by default "+".

inference(x, cat_covs, nullify_cat_covs_indices: List[int] | None = None, nullify_shared: bool | None = False) → dict[str, torch.Tensor]

Perform the inference step of the model.

Parameters:
  • x (torch.Tensor) – Input gene expression data.

  • cat_covs (torch.Tensor) – Categorical covariates.

  • nullify_cat_covs_indices (Optional[List[int]], optional) – Indices of categorical covariates to nullify, by default None.

  • nullify_shared (Optional[bool], optional) – Whether to nullify the shared latent space, by default False.

Returns:

Dictionary containing the inference outputs.

Return type:

dict[str, torch.Tensor]

generative(z_shared, zs, library, cat_covs)

Perform the generative step of the model.

Parameters:
  • z_shared (torch.Tensor) – Shared latent space tensor.

  • zs (list of torch.Tensor) – List of latent space tensors for each sensitive attribute.

  • library (torch.Tensor) – Library size tensor.

  • cat_covs (torch.Tensor) – Categorical covariates tensor.

Returns:

Dictionary containing the generative outputs.

Return type:

dict

sub_forward(idx, x, cat_covs, detach_x=False, detach_z=False)[source]

Performs forward (inference + generative) only on encoder/decoder idx.

Parameters:
  • idx (int) – Index of encoder/decoder in [1, …, self.zs_num].

  • x (torch.Tensor) – Input gene expression data.

  • cat_covs (torch.Tensor) – Categorical covariates.

  • detach_x (bool, optional) – Whether to detach the input tensor x, by default False.

  • detach_z (bool, optional) – Whether to detach the latent representation z, by default False.

Returns:

The reconstructed gene expression distribution.

Return type:

torch.distributions.Distribution

classification_logits(inference_outputs)[source]

Compute classification logits for each sensitive attribute.

Parameters:

inference_outputs (dict[str, torch.Tensor]) – Dictionary containing the outputs from the inference step.

Returns:

List of logits for each sensitive attribute.

Return type:

list[torch.Tensor]

sub_forward_cf(idx, x, cat_covs, cat_covs_cf=None, detach_x=False, detach_z=False)[source]

Perform a counterfactual forward pass for a specific encoder/decoder.

Parameters:
  • idx (int) – Index of the encoder/decoder to use.

  • x (torch.Tensor) – Input gene expression data.

  • cat_covs (torch.Tensor) – Original categorical covariates.

  • cat_covs_cf (torch.Tensor, optional) – Counterfactual categorical covariates. If None, use original covariates.

  • detach_x (bool, optional) – Whether to detach the input tensor x, by default False.

  • detach_z (bool, optional) – Whether to detach the latent representation z, by default False.

Returns:

The reconstructed gene expression distribution.

Return type:

torch.distributions.Distribution

sub_forward_cf_z0(x, cat_covs, cat_covs_cf=None, detach_x=False, detach_z=False)[source]

Perform a counterfactual forward pass (inference + generative) using only encoder/decoder 0.

Parameters:
  • x (torch.Tensor) – Input gene expression data.

  • cat_covs (torch.Tensor) – Original categorical covariates.

  • cat_covs_cf (torch.Tensor, optional) – Counterfactual categorical covariates. If None, use original covariates.

  • detach_x (bool, optional) – Whether to detach the input tensor x, by default False.

  • detach_z (bool, optional) – Whether to detach the latent representation z, by default False.

Returns:

The reconstructed gene expression distribution.

Return type:

torch.distributions.Distribution

sub_forward_cf_avg(x, cat_covs, cat_covs_cf=None, detach_x=False, detach_z=False)[source]

Perform a counterfactual forward pass through all encoders/decoders and average the results.

Parameters:
  • x (torch.Tensor) – Input gene expression data.

  • cat_covs (torch.Tensor) – Original categorical covariates.

  • cat_covs_cf (torch.Tensor, optional) – Counterfactual categorical covariates. If None, use original covariates.

  • detach_x (bool, optional) – Whether to detach the input tensor x, by default False.

  • detach_z (bool, optional) – Whether to detach the latent representation z, by default False.

Returns:

  • torch.Tensor – The average counterfactual gene expression.

  • list – List of reconstructed gene expression distributions.

compute_clf_metrics(logits, cat_covs)[source]

Compute classification metrics: cross-entropy (CE) loss, accuracy, and F1 score.

Parameters:
  • logits (list[torch.Tensor]) – List of logits for each sensitive attribute.

  • cat_covs (torch.Tensor) – Tensor containing the categorical covariates.

Returns:

A tuple containing the mean CE loss, accuracy, and F1 score.

Return type:

tuple
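The NumPy sketch below illustrates the three metrics by averaging cross-entropy, accuracy and F1 over the sensitive attributes. clf_metrics_sketch is hypothetical and is not the method's implementation. It assumes that column i of cat_covs holds the integer labels for attribute i and that F1 is macro-averaged. The docstring confirms neither assumption.

```python
from typing import List, Tuple

import numpy as np


def clf_metrics_sketch(logits: List[np.ndarray], cat_covs: np.ndarray) -> Tuple[float, float, float]:
    """Hypothetical: mean CE loss, accuracy and macro-F1 over attributes."""
    ces, accs, f1s = [], [], []
    for i, lg in enumerate(logits):
        # Assumption: column i of cat_covs holds the labels of attribute i.
        y = cat_covs[:, i].astype(int)
        # Numerically stable log-softmax, then negative log-likelihood.
        shifted = lg - lg.max(axis=1, keepdims=True)
        log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
        ces.append(-log_probs[np.arange(len(y)), y].mean())
        pred = lg.argmax(axis=1)
        accs.append((pred == y).mean())
        # Assumption: macro F1 over classes seen in labels or predictions.
        per_class = []
        for c in np.union1d(y, pred):
            tp = np.sum((pred == c) & (y == c))
            fp = np.sum((pred == c) & (y != c))
            fn = np.sum((pred != c) & (y == c))
            denom = 2 * tp + fp + fn
            per_class.append(2 * tp / denom if denom > 0 else 0.0)
        f1s.append(np.mean(per_class))
    return float(np.mean(ces)), float(np.mean(accs)), float(np.mean(f1s))
```

With all-zero logits over two classes, the CE term equals log 2. This is a quick sanity check for any implementation.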

loss(tensors, inference_outputs, generative_outputs, recon_weight: Tunable[Union[float, int]], cf_weight: Tunable[Union[float, int]], beta: Tunable[Union[float, int]], clf_weight: Tunable[Union[float, int]], n_cf: Tunable[int], kl_weight: float = 1.0, ensemble_method_cf=True)[source]

Compute the loss for the model.

Parameters:
  • tensors (dict) – Dictionary containing the input tensors.

  • inference_outputs (dict) – Dictionary containing the outputs from the inference step.

  • generative_outputs (dict) – Dictionary containing the outputs from the generative step.

  • recon_weight (Tunable[Union[float, int]]) – Weight for the reconstruction loss of X.

  • cf_weight (Tunable[Union[float, int]]) – Weight for the reconstruction loss of X_cf.

  • beta (Tunable[Union[float, int]]) – Weight for the KL divergence of Zi.

  • clf_weight (Tunable[Union[float, int]]) – Weight for the Si classifier loss.

  • n_cf (Tunable[int]) – Number of X_cf reconstructions (X_cf = a random permutation of X).

  • kl_weight (float, optional) – Weight for the KL divergence, by default 1.0.

  • ensemble_method_cf (bool, optional) – Whether to use the newer, ensemble-based counterfactual method, by default True.

Returns:

Dictionary containing the computed losses and metrics.

Return type:

dict

class celldisect.PerturbationEmbedding(*args: Any, **kwargs: Any)[source]

Bases: Module

Embedding module backed by predefined external vectors (e.g. ESM, GenePT).

For combinatorial perturbations (e.g. "GeneA+GeneB"), component embeddings are summed. The weight matrix is frozen by default.

Parameters:
  • predefined_embeddings – Mapping from atomic perturbation name to its vector representation.

  • category_names – Ordered list of category names that mirrors the integer-index mapping produced by scvi-tools’ CategoricalJointObsField.

  • combination_delimiter – String used to split combinatorial perturbation labels.
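The NumPy sketch below shows the summing rule described above. Row i of the matrix is the embedding of category_names[i], and a combinatorial label gets the sum of its components' vectors. build_perturbation_weight is a hypothetical helper and is not the class's actual implementation. It also treats atoms with no predefined vector (for example, a control label) as zero vectors, which is an assumption.

```python
from typing import Dict, Sequence

import numpy as np


def build_perturbation_weight(
    predefined_embeddings: Dict[str, np.ndarray],
    category_names: Sequence[str],
    combination_delimiter: str = "+",
) -> np.ndarray:
    """Hypothetical: row i is the (summed) embedding of category_names[i]."""
    dim = len(next(iter(predefined_embeddings.values())))
    weight = np.zeros((len(category_names), dim), dtype=np.float32)
    for i, name in enumerate(category_names):
        for atom in name.split(combination_delimiter):
            # Assumption: atoms without a vector (e.g. a control label)
            # contribute zeros instead of raising an error.
            vec = predefined_embeddings.get(atom)
            if vec is not None:
                weight[i] += np.asarray(vec, dtype=np.float32)
    return weight


emb = {"A": np.array([1.0, 0.0]), "B": np.array([0.0, 2.0])}
W = build_perturbation_weight(emb, ["ctrl", "A", "A+B"])
```

To mirror the frozen-by-default behaviour in PyTorch, you could wrap such a matrix with torch.nn.Embedding.from_pretrained(torch.from_numpy(W), freeze=True). The documentation does not say whether PerturbationEmbedding does exactly this internally.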

add_perturbation(name: str, embedding: numpy.ndarray | None = None) int[source]

Register a new perturbation (possibly unseen during training).

Parameters:
  • name – Perturbation label. Can be combinatorial ("A+B").

  • embedding – Vector for an atomic perturbation that is not yet in the predefined dictionary. Ignored for names already known.

Returns:

Integer index of the (new or existing) perturbation.

Return type:

int
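The registration semantics described for add_perturbation can be illustrated with a small stand-in class. PerturbationRegistrySketch is hypothetical and is not part of celldisect. It assumes two behaviours that the documentation leaves open, and the real method may differ in both cases:

- A combinatorial label is accepted only when all of its components already have vectors.
- An unknown atomic label passed without an embedding raises an error.

```python
from typing import Dict, List, Optional

import numpy as np


class PerturbationRegistrySketch:
    """Hypothetical stand-in illustrating add_perturbation semantics."""

    def __init__(self, predefined: Dict[str, np.ndarray], names: List[str], delimiter: str = "+"):
        self.predefined = dict(predefined)  # atomic name -> vector
        self.names = list(names)            # position == integer index
        self.delimiter = delimiter

    def add_perturbation(self, name: str, embedding: Optional[np.ndarray] = None) -> int:
        if name in self.names:
            # Known label: return its index; any embedding given is ignored.
            return self.names.index(name)
        atoms = name.split(self.delimiter)
        if len(atoms) == 1 and name not in self.predefined:
            # Assumption: an unseen atomic label must come with a vector.
            if embedding is None:
                raise ValueError(f"No embedding known or given for {name!r}")
            self.predefined[name] = np.asarray(embedding)
        for atom in atoms:
            # Assumption: every component of a combination must be known.
            if atom not in self.predefined:
                raise KeyError(f"Unknown atomic perturbation {atom!r}")
        self.names.append(name)
        return len(self.names) - 1

    def vector(self, name: str) -> np.ndarray:
        # Combinatorial labels: component embeddings are summed.
        return sum(self.predefined[a] for a in name.split(self.delimiter))
```

In this sketch, registering "B" with a vector and then "A+B" yields a new index whose vector is the sum of the two components.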

rebuild_for_mapping(new_category_names: list) None[source]

Rebuild the weight matrix for a different category-to-index mapping.

forward(indices: torch.Tensor) torch.Tensor[source]

property weight

Alias so callers expecting nn.Embedding attributes still work.

celldisect.perturbation_metrics(pred: numpy.ndarray, true: numpy.ndarray, ctrl: numpy.ndarray, top_n_de: int = 20) dict[source]

Compute standard perturbation prediction evaluation metrics.

Parameters:
  • pred – Predicted mean gene expression, shape (n_genes,) or (n_cells, n_genes). If 2-D the mean across cells is taken.

  • true – Ground-truth mean gene expression (same shape convention).

  • ctrl – Control (source) mean gene expression (same shape convention).

  • top_n_de – Number of top differentially expressed genes to evaluate.

Returns:

Dictionary with the following keys:

  • pearson_mean – Pearson r between predicted and true mean expression.

  • pearson_delta – Pearson r between predicted and true deltas (each relative to ctrl).

  • mse – Mean squared error of mean expression.

  • top_de_pearson – Pearson r on the top-N DE genes (ranked by |true - ctrl|).

  • top_de_cosine – Cosine similarity on the top-N DE genes.

Return type:

dict
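The metrics listed above can be approximated with plain NumPy, as in the sketch below. perturbation_metrics_sketch is hypothetical and is not the library's implementation. In particular, it computes top_de_pearson and top_de_cosine on deltas relative to ctrl rather than on raw mean expression, which is an assumption.

```python
import numpy as np


def perturbation_metrics_sketch(pred, true, ctrl, top_n_de=20):
    """Hypothetical re-implementation of the documented metric keys."""
    arrays = [np.asarray(a, dtype=float) for a in (pred, true, ctrl)]
    # 2-D inputs are (n_cells, n_genes): average over cells first.
    pred, true, ctrl = [a.mean(axis=0) if a.ndim == 2 else a for a in arrays]

    def pearson(a, b):
        return float(np.corrcoef(a, b)[0, 1])

    def cosine(a, b):
        return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

    d_pred, d_true = pred - ctrl, true - ctrl
    # Top-N DE genes ranked by |true - ctrl|, largest first.
    top = np.argsort(-np.abs(d_true))[:top_n_de]
    return {
        "pearson_mean": pearson(pred, true),
        "pearson_delta": pearson(d_pred, d_true),
        "mse": float(np.mean((pred - true) ** 2)),
        # Assumption: both top-DE scores use deltas vs. ctrl.
        "top_de_pearson": pearson(d_pred[top], d_true[top]),
        "top_de_cosine": cosine(d_pred[top], d_true[top]),
    }
```

Ranking on the ground-truth delta means the same genes are scored regardless of the prediction, so different models are compared on a fixed gene set.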