Model
Model Class: CellDISECT model for single-cell RNA sequencing data analysis.
Perturbation Embedding: Embedding module backed by predefined external vectors (e.g. ESM, GenePT).
Module Components
- class celldisect._model.CellDISECT(*args: Any, **kwargs: Any)
Bases: RNASeqMixin, VAEMixin, UnsupervisedTrainingMixin, BaseModelClass, TunableMixin
CellDISECT model for single-cell RNA sequencing data analysis.
- Parameters:
adata – AnnData object that has been registered via setup_anndata().
n_hidden – Number of nodes per hidden layer.
n_latent_shared – Dimensionality of the shared latent space.
n_latent_attribute – Dimensionality of the latent space for each sensitive attribute.
n_layers – Number of hidden layers used for encoder and decoder neural networks.
dropout_rate – Dropout rate for neural networks.
gene_likelihood –
One of:
'nb' - Negative binomial distribution
'zinb' - Zero-inflated negative binomial distribution
'poisson' - Poisson distribution
latent_distribution –
One of:
'normal' - Normal distribution
'ln' - Logistic normal distribution (Normal(0, I) transformed by softmax)
split_key – Key in adata.obs to split the data into training, validation, and test sets.
train_split – Values in split_key to be used for training.
valid_split – Values in split_key to be used for validation.
test_split – Values in split_key to be used for testing.
weighted_classifier – Whether to use weighted classifiers for categorical covariates.
**model_kwargs – Additional keyword arguments for the model.
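A small NumPy sketch can illustrate the 'ln' option. It draws z from Normal(0, I) and passes each sample through a softmax, so every sample lies on the probability simplex. The function name sample_logistic_normal is hypothetical and is not part of CellDISECT. The model's own implementation may parameterize the distribution differently, for example with learned means and variances rather than Normal(0, I).

```python
import numpy as np


def sample_logistic_normal(n_samples: int, n_latent: int, seed: int = 0) -> np.ndarray:
    """Draw z ~ Normal(0, I) and map each row onto the simplex with a softmax."""
    rng = np.random.default_rng(seed)
    z = rng.standard_normal((n_samples, n_latent))
    # Subtracting the row maximum keeps exp() from overflowing; the softmax is unchanged.
    z = z - z.max(axis=1, keepdims=True)
    expz = np.exp(z)
    return expz / expz.sum(axis=1, keepdims=True)


samples = sample_logistic_normal(n_samples=4, n_latent=10)
print(samples.shape)        # (4, 10)
print(samples.sum(axis=1))  # every row sums to 1 (up to floating-point error)
```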
- classmethod setup_anndata(adata: anndata.AnnData, layer: str | None = None, batch_key: str | None = None, labels_key: str | None = None, size_factor_key: str | None = None, categorical_covariate_keys: List[str] | None = None, continuous_covariate_keys: List[str] | None = None, add_cluster_covariate: bool = False, clustering_normalize_counts: bool = True, perturbation_key: str | None = None, perturbation_embedding_key: str | None = None, perturbation_combination_delimiter: str = '+', **kwargs)
Set up the AnnData object for the CellDISECT model.
This method configures the AnnData object by registering the necessary fields and, optionally, adding a cluster covariate. When perturbation_key is provided, the corresponding column in adata.obs is treated as a perturbation covariate whose embeddings come from adata.uns[perturbation_embedding_key] rather than being learned during training.
- Parameters:
adata (AnnData) – AnnData object to be set up.
layer (Optional[str], optional) – Layer in adata to use as the count data, by default None.
batch_key (Optional[str], optional) – Key in adata.obs for batch information, by default None.
labels_key (Optional[str], optional) – Key in adata.obs for labels, by default None.
size_factor_key (Optional[str], optional) – Key in adata.obs for size factors, by default None.
categorical_covariate_keys (Optional[List[str]], optional) – List of keys in adata.obs for categorical covariates, by default None.
continuous_covariate_keys (Optional[List[str]], optional) – List of keys in adata.obs for continuous covariates, by default None.
add_cluster_covariate (bool, optional) – Whether to add a cluster covariate to adata.obs, by default False.
clustering_normalize_counts (bool, optional) – Whether to normalize counts before clustering, by default True.
perturbation_key (Optional[str], optional) – Column in adata.obs that contains perturbation labels (e.g. "GeneA", "GeneA+GeneB"). When set, the perturbation covariate uses predefined embeddings instead of learned ones.
perturbation_embedding_key (Optional[str], optional) – Key in adata.uns whose value is a dict[str, np.ndarray] mapping atomic perturbation names to their vector representations (e.g. ESM or GenePT embeddings). Required when perturbation_key is set.
perturbation_combination_delimiter (str, optional) – Delimiter for combinatorial perturbation labels, by default "+".
**kwargs – Additional keyword arguments.
- Return type:
None
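When perturbation_key is set, every atomic perturbation in adata.obs[perturbation_key] presumably needs a vector in adata.uns[perturbation_embedding_key]. The sketch below is a hypothetical pre-flight check. The helper missing_atomic_perturbations is not a CellDISECT function. It splits combinatorial labels on the delimiter and reports atoms that have no embedding. This reference does not say whether control labels such as "ctrl" need their own vector, so the sketch treats them like any other label.

```python
import numpy as np


def missing_atomic_perturbations(labels, embeddings, delimiter="+"):
    """Return sorted atomic names found in `labels` that have no entry in `embeddings`."""
    missing = set()
    for label in set(labels):
        # A combinatorial label such as "GeneA+GeneB" is split into its atomic parts.
        for atom in label.split(delimiter):
            if atom not in embeddings:
                missing.add(atom)
    return sorted(missing)


# Toy stand-ins for adata.obs[perturbation_key] and adata.uns[perturbation_embedding_key].
obs_labels = ["GeneA", "GeneB", "GeneA+GeneB", "GeneA+GeneC"]
rng = np.random.default_rng(0)
pert_embeddings = {name: rng.standard_normal(8) for name in ["GeneA", "GeneB"]}

print(missing_atomic_perturbations(obs_labels, pert_embeddings))  # ['GeneC']
```

In a real workflow you would store the dictionary under a key of your choice, for example adata.uns["pert_emb"]. You would then pass perturbation_embedding_key="pert_emb" to setup_anndata() together with perturbation_key.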
- classmethod add_cluster_covariate(adata: anndata.AnnData, normalize_counts: bool = True)
Run PCA on the gene expression matrix, then run Leiden clustering on the PCA components to create a cluster covariate that is added to adata.obs.
- Parameters:
adata (AnnData) – AnnData object containing the single-cell RNA sequencing data.
normalize_counts (bool, optional) – If True, takes the counts from adata.layers['counts'] and log-normalizes them before clustering, by default True.
- Return type:
None
- predict_given_covs_depricated(adata: anndata.AnnData, cats: List[str], cov_idx: int, cov_value_cf, batch_size: int | None = None) → Tuple[torch.Tensor, torch.Tensor]
- predict_counterfactuals(adata: anndata.AnnData, cov_names: list[str], cov_values: list[str], cov_values_cf: list[str], cats: list[str], n_samples_from_source: int | None = None, seed: int | None = 0)
Predicts counterfactuals for a given subset of data.
This function estimates the counterfactual outcomes for a subset of data based on specified changes in covariate values.
- Parameters:
adata (AnnData) – The subset of the data for which the counterfactuals are to be predicted.
cov_names (list[str]) – Names of the covariates that are to be changed.
cov_values (list[str]) – Original values for the covariates that are to be changed.
cov_values_cf (list[str]) – Counterfactual values for the covariates that are to be changed.
cats (list[str]) – List of categorical covariate keys, in the same order used during training.
n_samples_from_source (Optional[int], optional) – Number of samples to take from the source data to predict the counterfactuals. If None, all samples from the source data are used. Defaults to None.
seed (Optional[int], optional) – Random seed for reproducibility. Defaults to 0. Only used if n_samples_from_source is not None.
- Returns:
A tuple (x_ctrl, x_true, x_pred) holding the control (source) counts, the true counterfactual counts and the predicted counterfactual counts. The values are raw counts (not log-transformed).
- Return type:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
Examples
Single covariate change:
    cats = ['cell_type', 'condition']
    cell_type_to_check = ['CD4 T']
    cov_names = ['condition']
    cov_values = ['ctrl']
    cov_values_cf = ['stimulated']
    n_samples_from_source = 500
    x_ctrl, x_true, x_pred = model.predict_counterfactuals(
        adata[(adata.obs['cell_type'].isin(cell_type_to_check))].copy(),
        cov_names=cov_names,
        cov_values=cov_values,
        cov_values_cf=cov_values_cf,
        cats=cats,
        n_samples_from_source=n_samples_from_source,
    )
Multiple covariate change:
    cats = ['tissue', 'Sample ID', 'sex', 'Age_bin', 'CoarseCellType']
    cell_type_to_check = 'Epithelial cell (luminal)'
    cov_names = ['sex', 'tissue']
    cov_values = ['female', 'breast']
    cov_values_cf = ['male', 'prostate gland']
    n_samples_from_source = None
    x_ctrl, x_true, x_pred = model.predict_counterfactuals(
        adata[adata.obs['Broad cell type'] == cell_type_to_check].copy(),
        cov_names=cov_names,
        cov_values=cov_values,
        cov_values_cf=cov_values_cf,
        cats=cats,
        n_samples_from_source=n_samples_from_source,
    )
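The examples above split cells into a source group and a true-counterfactual group. The following pure-Python sketch shows one way that selection could work. It assumes that source cells are those whose cov_names columns equal cov_values, and that target cells are those that equal cov_values_cf. The source group can optionally be subsampled with a seed. select_source_and_target is a hypothetical helper, and the obs rows are plain dicts standing in for adata.obs. The actual method may differ, for example when n_samples_from_source is larger than the source group; this sketch simply caps it.

```python
import random


def select_source_and_target(obs, cov_names, cov_values, cov_values_cf,
                             n_samples_from_source=None, seed=0):
    """Return (source_indices, target_indices) for a list of per-cell dicts."""

    def matches(row, values):
        # A cell matches when every listed covariate has the requested value.
        return all(row[name] == value for name, value in zip(cov_names, values))

    source = [i for i, row in enumerate(obs) if matches(row, cov_values)]
    target = [i for i, row in enumerate(obs) if matches(row, cov_values_cf)]
    if n_samples_from_source is not None:
        # The seed is only used when subsampling the source group.
        rng = random.Random(seed)
        k = min(n_samples_from_source, len(source))
        source = sorted(rng.sample(source, k))
    return source, target


obs = [
    {"cell_type": "CD4 T", "condition": "ctrl"},
    {"cell_type": "CD4 T", "condition": "stimulated"},
    {"cell_type": "CD4 T", "condition": "ctrl"},
    {"cell_type": "CD4 T", "condition": "stimulated"},
]
src, tgt = select_source_and_target(obs, ["condition"], ["ctrl"], ["stimulated"])
print(src, tgt)  # [0, 2] [1, 3]
```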
- get_latent_representation(adata: anndata.AnnData | None = None, indices: Sequence[int] | None = None, batch_size: int | None = None, nullify_cat_covs_indices: List[int] | None = None, nullify_shared: bool | None = False) numpy.ndarray | Tuple[numpy.ndarray, numpy.ndarray]
Get the latent representation of the data.
This function computes the latent representation of the data using the trained model. It allows for optional nullification of specific categorical covariates or the shared latent space.
- Parameters:
adata (Optional[AnnData]) – Annotated data object. If None, uses the data registered with the model.
indices (Optional[Sequence[int]]) – Optional indices to subset the data.
batch_size (Optional[int]) – Batch size to use for data loading. If None, uses the default batch size.
nullify_cat_covs_indices (Optional[List[int]]) – List of indices of categorical covariates to nullify in the latent space. If None, no covariates are nullified.
nullify_shared (Optional[bool]) – If True, nullifies the shared latent space. Defaults to False.
- Returns:
The latent representation of the data. If nullify_cat_covs_indices or nullify_shared is specified, returns a tuple of the latent representation and the nullified latent representation.
- Return type:
Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]
- get_cat_covariate_latents()
Returns the embeddings of the categorical covariates.
This function retrieves the embeddings of the categorical covariates from the trained model. It also returns the mappings of the categorical covariates.
- Parameters:
self (object) – The instance of the class.
- Returns:
A tuple containing two dictionaries:
covar_embeddings: Dictionary where keys are covariate names and values are the embeddings as numpy arrays.
covar_mappings: Dictionary where keys are covariate names and values are the mappings of the covariates.
- Return type:
Tuple[dict, dict]
- get_gene_importance() → pandas.DataFrame
Computes gene importance for each encoder.
This method calculates a gene importance score for each gene based on the weights of the encoders. The importance is calculated by propagating the weights through the network layers to obtain an effective weight matrix from input genes to the latent space. This method is most accurate when the model is trained without biases and with a non-negative activation function like ReLU. If biases are present, they are ignored in this calculation.
- Returns:
A pandas DataFrame containing gene importance scores. The rows correspond to genes, and the columns correspond to the different encoders (e.g., ‘shared’, ‘attribute_0’, ‘attribute_1’, …).
- Return type:
pd.DataFrame
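The weight-propagation idea can be sketched in NumPy. Multiplying the encoder's weight matrices (PyTorch layout, out_features × in_features) gives an effective genes-to-latent matrix; biases and activations are ignored. This reference does not say how CellDISECT turns that matrix into one score per gene. The sketch assumes the sum of absolute effective weights over latent dimensions, and both function names are hypothetical.

```python
import numpy as np


def effective_gene_weights(layer_weights):
    """Chain weight matrices stored as (out_features, in_features), first layer first.

    Biases and activations are ignored, so the result is exact only for a linear path.
    Returns an array of shape (n_latent, n_genes).
    """
    w_eff = layer_weights[0]
    for w in layer_weights[1:]:
        w_eff = w @ w_eff
    return w_eff


def gene_importance(layer_weights):
    # Assumed aggregation: total absolute effective weight onto all latent dimensions.
    return np.abs(effective_gene_weights(layer_weights)).sum(axis=0)


w1 = np.array([[1.0, 0.0, 0.5],
               [0.0, 2.0, 0.0]])  # hidden (2) x genes (3)
w2 = np.array([[1.0, 1.0]])       # latent (1) x hidden (2)
print(gene_importance([w1, w2]).tolist())  # [1.0, 2.0, 0.5]
```

In the documented method, each encoder (shared, attribute_0, and so on) contributes one column of the returned DataFrame.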
- train(max_epochs: int | None = None, use_gpu: str | int | bool | None = True, train_size: float = 0.8, validation_size: float | None = None, batch_size: int = 256, early_stopping: bool = True, save_best: bool = False, plan_kwargs: dict | None = None, recon_weight: Tunable[Union[float, int]] = 10, cf_weight: Tunable[Union[float, int]] = 1, beta: Tunable[Union[float, int]] = 1, clf_weight: Tunable[Union[float, int]] = 50, adv_clf_weight: Tunable[Union[float, int]] = 10, adv_period: Tunable[int] = 1, n_cf: Tunable[int] = 10, kappa_optimizer2: bool = True, n_epochs_pretrain_ae: int = 0, **trainer_kwargs)
Train the model.
- Parameters:
max_epochs (Optional[int]) – Number of passes through the dataset. If None, defaults to np.min([round((20000 / n_cells) * 400), 400]).
use_gpu (Optional[Union[str, int, bool]]) – Whether to use GPU for training. Can be a boolean, string, or integer specifying the GPU device.
train_size (float) – Size of the training set in the range [0.0, 1.0].
validation_size (Optional[float]) – Size of the validation set. If None, defaults to 1 - train_size. If train_size + validation_size < 1, the remaining cells belong to a test set.
batch_size (int) – Minibatch size to use during training.
early_stopping (bool) – Perform early stopping. Additional arguments can be passed in **kwargs. See Trainer for further options.
save_best (bool) – Save the best model state with respect to the validation loss (default), or use the final state in the training procedure.
plan_kwargs (Optional[dict]) – Keyword arguments for TrainingPlan. Keyword arguments passed to train() will overwrite values present in plan_kwargs, when appropriate.
recon_weight (Tunable[Union[float, int]]) – Weight for the reconstruction loss of X.
cf_weight (Tunable[Union[float, int]]) – Weight for the reconstruction loss of X_cf.
beta (Tunable[Union[float, int]]) – Weight for the KL divergence of Zi.
clf_weight (Tunable[Union[float, int]]) – Weight for the Si classifier loss.
adv_clf_weight (Tunable[Union[float, int]]) – Weight for the adversarial classifier loss.
adv_period (Tunable[int]) – Adversarial training period.
n_cf (Tunable[int]) – Number of X_cf reconstructions (a random permutation of n VAEs and a random half-batch subset for each trial).
kappa_optimizer2 (bool) – Whether to use the second kappa optimizer.
n_epochs_pretrain_ae (int) – Number of epochs to pretrain the autoencoder.
**trainer_kwargs – Other keyword arguments for Trainer.
- Return type:
None
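Two of the defaults above are simple arithmetic. The sketch reproduces the documented max_epochs formula and the train/validation/test split rule in plain Python. The function names are hypothetical. Clamping the test fraction at zero is an added assumption for inputs whose sizes sum to more than 1.

```python
def default_max_epochs(n_cells: int) -> int:
    # Documented default: np.min([round((20000 / n_cells) * 400), 400]).
    return min(round((20000 / n_cells) * 400), 400)


def split_fractions(train_size: float = 0.8, validation_size=None):
    """Return (train, validation, test) fractions following the rules for train()."""
    if validation_size is None:
        validation_size = 1.0 - train_size
    # Cells not assigned to training or validation form the test set.
    test_size = max(0.0, 1.0 - train_size - validation_size)
    return train_size, validation_size, test_size


print(default_max_epochs(10_000))   # 400 (capped)
print(default_max_epochs(100_000))  # 80
print(split_fractions(0.7, 0.2))    # roughly (0.7, 0.2, 0.1)
```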
- generate(n_samples: int, covariates: dict[str, str], library_size: float | None = None, adata_for_library_size: anndata.AnnData | None = None, random_seed: int | None = None, include_shared_latent: bool = False)
Generate new cells conditionally based on specified categorical covariates.
This method generates n_samples new cells for a given combination of categorical covariates. The generation process involves sampling from the latent spaces and then decoding to get gene expression profiles.
The shared latent space (z_shared) is sampled from a standard normal distribution. The attribute-specific latent spaces (zs) are sampled from their respective prior encoders, conditioned on the provided covariates.
- Parameters:
n_samples – Number of cells to generate.
covariates – A dictionary mapping categorical covariate names to the desired values for generation. All categorical covariates used to train the model must be specified.
library_size – A specific library size to use for generation. Mutually exclusive with adata_for_library_size.
adata_for_library_size – An AnnData object to infer the median library size from. If cells with the specified covariates combination exist in this AnnData object, their median library size is used. Otherwise, the median library size of the entire adata_for_library_size is used. Mutually exclusive with library_size.
random_seed – A random seed for reproducibility.
include_shared_latent – Whether to include the expression from the decoder associated with the shared latent space (Z_0) in the average_expression. If False, the average is computed over attribute-specific decoders only. We suggest setting this to False for most use cases. The shared decoder is assumed to be decoder_0.
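The library-size rules for generate() can be written as a small decision function. In this sketch, cells is a hypothetical list of (covariate dict, total counts) pairs standing in for adata_for_library_size, and the helper name is illustrative only. The sketch requires exactly one source of library size. This reference does not document what generate() does when neither argument is passed.

```python
from statistics import median


def resolve_library_size(covariates, library_size=None, cells=None):
    """Return the library size to use for generation.

    `cells` is a list of (covariate_dict, total_counts) pairs standing in for
    adata_for_library_size.
    """
    if (library_size is None) == (cells is None):
        raise ValueError("Pass exactly one of library_size or cells.")
    if library_size is not None:
        return float(library_size)
    matching = [total for obs, total in cells
                if all(obs.get(key) == value for key, value in covariates.items())]
    # Fall back to the median over all cells if no cell has this covariate combination.
    pool = matching if matching else [total for _, total in cells]
    return float(median(pool))


cells = [
    ({"cell_type": "T-cell", "condition": "diseased"}, 1000),
    ({"cell_type": "T-cell", "condition": "diseased"}, 3000),
    ({"cell_type": "B-cell", "condition": "healthy"}, 8000),
]
print(resolve_library_size({"cell_type": "T-cell", "condition": "diseased"}, cells=cells))  # 2000.0
print(resolve_library_size({"cell_type": "B-cell", "condition": "diseased"}, cells=cells))  # 3000.0
```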
- Returns:
A tuple (average_expression, all_expressions), where:
average_expression (numpy.ndarray): An array of shape (n_samples, n_vars) representing the average gene expression across the selected decoders.
all_expressions (dict[str, numpy.ndarray]): A dictionary whose keys are decoder names (decoder_0, decoder_1, etc.) and whose values are numpy arrays of shape (n_samples, n_vars) holding the generated gene expression counts from each individual decoder.
- Return type:
Tuple[numpy.ndarray, dict[str, numpy.ndarray]]
Examples
>>> n_generated_cells = 10
>>> desired_covariates = {"cell_type": "T-cell", "condition": "diseased"}
>>> avg_expr, all_expr = model.generate(
...     n_samples=n_generated_cells,
...     covariates=desired_covariates,
...     adata_for_library_size=adata,
... )
To compute a custom average using only a subset of decoders (e.g., decoder_0 and decoder_2):
>>> import numpy as np
>>> custom_avg = np.mean(
...     [all_expr["decoder_0"], all_expr["decoder_2"]], axis=0
... )
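The averaging rule behind average_expression, as described for include_shared_latent, can be sketched as follows. The sketch assumes, as the docstring states, that decoder_0 is the shared decoder. average_over_decoders is a hypothetical helper that works on a dictionary shaped like all_expressions.

```python
import numpy as np


def average_over_decoders(all_expressions, include_shared_latent=False):
    """Average per-decoder arrays; "decoder_0" is treated as the shared decoder."""
    selected = [name for name in all_expressions
                if include_shared_latent or name != "decoder_0"]
    return np.mean([all_expressions[name] for name in selected], axis=0)


all_expr = {
    "decoder_0": np.full((2, 3), 9.0),
    "decoder_1": np.full((2, 3), 1.0),
    "decoder_2": np.full((2, 3), 3.0),
}
print(average_over_decoders(all_expr)[0].tolist())  # [2.0, 2.0, 2.0]
# Including the shared decoder gives (9 + 1 + 3) / 3 in every entry.
with_shared = average_over_decoders(all_expr, include_shared_latent=True)
```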
- add_perturbation_embedding(name: str, embedding: numpy.ndarray) → None
Register an embedding vector for an unseen atomic perturbation.
After calling this method the model can predict counterfactuals for the new perturbation (or combinatorial perturbations that include it).
- Parameters:
name – Atomic perturbation name (e.g. "GeneC").
embedding – Vector representation with the same dimensionality as the predefined embeddings used during training.
- get_perturbation_embeddings() → Tuple[List[str], numpy.ndarray]
Return current perturbation names and their (combined) embeddings.
- Returns:
names – Ordered list of perturbation category names.
embeddings – Numpy array of shape (n_perturbations, emb_dim).
- predict_perturbation(adata: anndata.AnnData, perturbation: str, source_perturbation: str, cats: List[str], perturbation_key: str, new_embeddings: dict | None = None, n_samples_from_source: int | None = None, seed: int | None = 0, device: str | torch.device | None = None, batch_size: int | None = 256, source_adata: anndata.AnnData | None = None) → Tuple[torch.Tensor, torch.Tensor | None, torch.Tensor]
Predict gene expression after a perturbation.
Memory-efficient implementation that avoids copying large adata objects.
- Parameters:
adata – AnnData object containing cells. Must have a 'counts' layer.
perturbation – Target perturbation label (e.g. "GeneA" or "GeneA+GeneB").
source_perturbation – Source / control perturbation label (e.g. "ctrl").
cats – List of categorical covariate keys (same order used during training).
perturbation_key – Which element of cats is the perturbation covariate.
new_embeddings – Optional dictionary mapping atomic perturbation names to embedding vectors.
n_samples_from_source – If set, randomly sample this many cells from the source group.
seed – Random seed for reproducibility.
device – Device to run prediction on. If None, uses the model’s device.
batch_size – Batch size for forward passes. Default 256.
source_adata – Optional AnnData with control cells.
- Returns:
x_ctrl – Source (control) counts, shape (n_source, n_genes).
x_true – Target (ground-truth) counts if cells exist, else None.
x_pred – Predicted counts, shape (n_source, n_genes).
- predict_perturbations(adata: anndata.AnnData, perturbations: List[str], source_perturbation: str, cats: List[str], perturbation_key: str, new_embeddings: dict | None = None, n_samples_from_source: int | None = None, seed: int | None = 0, device: str | torch.device | None = None, batch_size: int | None = 256, source_adata: anndata.AnnData | None = None) → dict
Predict gene expression for multiple perturbations.
- Parameters:
adata – AnnData object containing cells. Must have a 'counts' layer.
perturbations – List of target perturbation labels.
source_perturbation – Source / control perturbation label.
cats – List of categorical covariate keys.
perturbation_key – Which element of cats is the perturbation covariate.
new_embeddings – Optional dictionary mapping atomic perturbation names to embedding vectors.
n_samples_from_source – If set, randomly sample this many cells from the source group.
seed – Random seed for reproducibility.
device – Device to run prediction on.
batch_size – Batch size for forward passes. Default 256.
source_adata – Optional AnnData with control cells.
- Returns:
Dictionary mapping each perturbation to (x_ctrl, x_true, x_pred).
- Return type:
dict
- class celldisect._module.PerturbationEmbedding(*args: Any, **kwargs: Any)
Bases: Module
Embedding module backed by predefined external vectors (e.g. ESM, GenePT).
For combinatorial perturbations (e.g. "GeneA+GeneB"), component embeddings are summed. The weight matrix is frozen by default.
- Parameters:
predefined_embeddings – Mapping from atomic perturbation name to its vector representation.
category_names – Ordered list of category names that mirrors the integer-index mapping produced by scvi-tools’ CategoricalJointObsField.
combination_delimiter – String used to split combinatorial perturbation labels.
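The following minimal NumPy sketch shows how such a weight matrix could be assembled. Each row corresponds to one entry of category_names. A combinatorial label's row is the sum of its component vectors, as described above. build_perturbation_weight is hypothetical and omits freezing, since a NumPy array has no gradients anyway. It raises KeyError for any atom missing from predefined, because this reference does not cover how CellDISECT handles such labels (for example, controls).

```python
import numpy as np


def build_perturbation_weight(predefined, category_names, delimiter="+"):
    """Return an (n_categories, emb_dim) matrix whose row order follows category_names."""
    rows = []
    for name in category_names:
        # Combinatorial labels are split and their component vectors summed.
        parts = name.split(delimiter)
        rows.append(np.sum([predefined[part] for part in parts], axis=0))
    return np.stack(rows)


predefined = {"GeneA": np.array([1.0, 0.0]), "GeneB": np.array([0.0, 2.0])}
weight = build_perturbation_weight(predefined, ["GeneA", "GeneB", "GeneA+GeneB"])
print(weight.tolist())  # [[1.0, 0.0], [0.0, 2.0], [1.0, 2.0]]
```

Conceptually, rebuild_for_mapping resembles re-running such a construction with a new category_names order. This is an interpretation of its one-line description, not documented behavior.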
- add_perturbation(name: str, embedding: numpy.ndarray | None = None) → int
Register a new perturbation (possibly unseen during training).
- Parameters:
name – Perturbation label. Can be combinatorial ("A+B").
embedding – Vector for an atomic perturbation that is not yet in the predefined dictionary. Ignored for names already known.
- Returns:
Integer index of the (new or existing) perturbation.
- Return type:
int
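The documented contract of add_perturbation has three parts: return the existing index for known labels, accept a vector for an unseen atomic perturbation, and support combinatorial labels. A small registry can mimic that contract. PerturbationRegistry is a hypothetical stand-in, not the CellDISECT class. Two of its choices are assumptions: it accepts a vector only when exactly one atom of the label is unseen, and it checks the vector's length against the existing embeddings, as add_perturbation_embedding requires.

```python
import numpy as np


class PerturbationRegistry:
    """Hypothetical stand-in mimicking the documented add_perturbation contract."""

    def __init__(self, predefined, category_names, delimiter="+"):
        self.predefined = {k: np.asarray(v, dtype=float) for k, v in predefined.items()}
        self.category_names = list(category_names)
        self.delimiter = delimiter
        self.emb_dim = next(iter(self.predefined.values())).shape[0]

    def add_perturbation(self, name, embedding=None):
        if name in self.category_names:
            # Known label: reuse its index; `embedding` is ignored.
            return self.category_names.index(name)
        unseen = [a for a in name.split(self.delimiter) if a not in self.predefined]
        if unseen:
            if embedding is None or len(unseen) != 1:
                raise ValueError(f"Need exactly one vector for unseen atoms {unseen}")
            vector = np.asarray(embedding, dtype=float)
            if vector.shape != (self.emb_dim,):
                raise ValueError("Embedding dimensionality does not match training")
            self.predefined[unseen[0]] = vector
        self.category_names.append(name)
        return len(self.category_names) - 1


reg = PerturbationRegistry({"GeneA": [1, 0], "GeneB": [0, 2]}, ["GeneA", "GeneB"])
print(reg.add_perturbation("GeneB"))                    # 1 (already known)
print(reg.add_perturbation("GeneC", embedding=[3, 3]))  # 2 (new atomic perturbation)
print(reg.add_perturbation("GeneA+GeneC"))              # 3 (combination of known atoms)
```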
- rebuild_for_mapping(new_category_names: list) → None
Rebuild the weight matrix for a different category-to-index mapping.
- forward(indices: torch.Tensor) → torch.Tensor
- property weight
Alias so that callers expecting nn.Embedding attributes still work.
- class celldisect._module.CellDISECTModule(*args: Any, **kwargs: Any)
Bases: BaseModuleClass
Variational auto-encoder module.
- Parameters:
n_input (int) – Number of input genes.
n_hidden (Tunable[int], optional) – Number of nodes per hidden layer, by default 128.
n_latent_shared (Tunable[int], optional) – Dimensionality of the shared latent space (Z_{-s}), by default 10.
n_latent_attribute (Tunable[int], optional) – Dimensionality of the latent space for each sensitive attribute (Z_{s_i}), by default 10.
n_layers (Tunable[int], optional) – Number of hidden layers used for encoder and decoder NNs, by default 1.
n_cats_per_cov (Optional[Iterable[int]], optional) – Number of categories for each extra categorical covariate, by default None.
dropout_rate (Tunable[float], optional) – Dropout rate for neural networks, by default 0.1.
log_variational (bool, optional) – Whether to apply log(data + 1) before encoding, for numerical stability (this is not a normalization step), by default True.
gene_likelihood (Tunable[Literal["zinb", "nb", "poisson"]], optional) – One of ‘nb’ (Negative binomial distribution), ‘zinb’ (Zero-inflated negative binomial distribution), or ‘poisson’ (Poisson distribution), by default “zinb”.
latent_distribution (Tunable[Literal["normal", "ln"]], optional) – One of ‘normal’ (Isotropic normal) or ‘ln’ (Logistic normal with normal params N(0, 1)), by default “normal”.
deeply_inject_covariates (Tunable[bool], optional) – Whether to concatenate covariates into output of hidden layers in encoder/decoder. This option only applies when n_layers > 1. The covariates are concatenated to the input of subsequent hidden layers, by default True.
use_batch_norm (Tunable[Literal["encoder", "decoder", "none", "both"]], optional) – Whether to use batch norm in layers, by default “both”.
use_layer_norm (Tunable[Literal["encoder", "decoder", "none", "both"]], optional) – Whether to use layer norm in layers, by default “none”.
var_activation (Optional[Callable], optional) – Callable used to ensure positivity of the variational distributions’ variance. When None, defaults to torch.exp, by default None.
use_custom_embs (bool, optional) – Whether to use custom embeddings, by default False.
embeddings (Union[torch.Tensor, List[torch.Tensor]], optional) – Custom embeddings to use if use_custom_embs is True, by default None.
classifier_weights (Optional[list], optional) – Weights for the classifiers, by default None.
bias (bool, optional) – Whether to use bias in the encoder and decoder layers, by default True.
perturbation_cov_idx (Optional[int], optional) – Index (in the categorical covariates list) of the perturbation covariate. When set, that covariate uses predefined embeddings instead of learned ones.
predefined_pert_embeddings (Optional[dict], optional) – Mapping from atomic perturbation name to its vector representation (e.g. ESM or GenePT embeddings). Required when perturbation_cov_idx is not None.
perturbation_category_names (Optional[list], optional) – Ordered category names for the perturbation covariate (mirrors the integer-index mapping from the AnnData manager).
perturbation_combination_delimiter (str, optional) – Delimiter for combinatorial perturbation labels, by default "+".
- inference(x, cat_covs, nullify_cat_covs_indices: List[int] | None = None, nullify_shared: bool | None = False) → dict[str, torch.Tensor]
Perform the inference step of the model.
- Parameters:
x (torch.Tensor) – Input gene expression data.
cat_covs (torch.Tensor) – Categorical covariates.
nullify_cat_covs_indices (Optional[List[int]], optional) – Indices of categorical covariates to nullify, by default None.
nullify_shared (Optional[bool], optional) – Whether to nullify the shared latent space, by default False.
- Returns:
Dictionary containing the inference outputs.
- Return type:
dict[str, torch.Tensor]
- generative(z_shared, zs, library, cat_covs)
Perform the generative step of the model.
- Parameters:
z_shared (torch.Tensor) – Shared latent space tensor.
zs (list of torch.Tensor) – List of latent space tensors for each sensitive attribute.
library (torch.Tensor) – Library size tensor.
cat_covs (torch.Tensor) – Categorical covariates tensor.
- Returns:
Dictionary containing the generative outputs.
- Return type:
dict
- sub_forward(idx, x, cat_covs, detach_x=False, detach_z=False)
Performs forward (inference + generative) only on encoder/decoder idx.
- Parameters:
idx (int) – Index of encoder/decoder in [1, …, self.zs_num].
x (torch.Tensor) – Input gene expression data.
cat_covs (torch.Tensor) – Categorical covariates.
detach_x (bool, optional) – Whether to detach the input tensor x, by default False.
detach_z (bool, optional) – Whether to detach the latent representation z, by default False.
- Returns:
The reconstructed gene expression distribution.
- Return type:
torch.distributions.Distribution
- classification_logits(inference_outputs)
Compute classification logits for each sensitive attribute.
- sub_forward_cf(idx, x, cat_covs, cat_covs_cf=None, detach_x=False, detach_z=False)
Perform counterfactual forward pass for a specific encoder/decoder.
- Parameters:
idx (int) – Index of the encoder/decoder to use.
x (torch.Tensor) – Input gene expression data.
cat_covs (torch.Tensor) – Original categorical covariates.
cat_covs_cf (torch.Tensor, optional) – Counterfactual categorical covariates. If None, use original covariates.
detach_x (bool, optional) – Whether to detach the input tensor x.
detach_z (bool, optional) – Whether to detach the latent representation z.
- Returns:
The reconstructed gene expression distribution.
- Return type:
torch.distributions.Distribution
- sub_forward_cf_z0(x, cat_covs, cat_covs_cf=None, detach_x=False, detach_z=False)
Performs counterfactual forward (inference + generative) only on encoder/decoder 0.
- Parameters:
x (torch.Tensor) – Input gene expression data.
cat_covs (torch.Tensor) – Original categorical covariates.
cat_covs_cf (torch.Tensor, optional) – Counterfactual categorical covariates. If None, use original covariates.
detach_x (bool, optional) – Whether to detach the input tensor x, by default False.
detach_z (bool, optional) – Whether to detach the latent representation z, by default False.
- Returns:
The reconstructed gene expression distribution.
- Return type:
torch.distributions.Distribution
- sub_forward_cf_avg(x, cat_covs, cat_covs_cf=None, detach_x=False, detach_z=False)
Perform counterfactual forward pass for all encoders/decoders and average the results.
- Parameters:
x (torch.Tensor) – Input gene expression data.
cat_covs (torch.Tensor) – Original categorical covariates.
cat_covs_cf (torch.Tensor, optional) – Counterfactual categorical covariates. If None, use original covariates.
detach_x (bool, optional) – Whether to detach the input tensor x, by default False.
detach_z (bool, optional) – Whether to detach the latent representation z, by default False.
- Returns:
torch.Tensor – The average counterfactual gene expression.
list – List of reconstructed gene expression distributions.
- compute_clf_metrics(logits, cat_covs)
Compute classification metrics: Cross-Entropy (CE) loss, Accuracy, and F1 score.
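The three metrics can be sketched in NumPy for a single categorical attribute. This reference does not specify how F1 is averaged; the sketch assumes macro averaging over classes. classification_metrics is a hypothetical name, not the module's implementation.

```python
import numpy as np


def classification_metrics(logits, labels):
    """Cross-entropy, accuracy and macro-averaged F1 for one categorical attribute."""
    logits = np.asarray(logits, dtype=float)
    labels = np.asarray(labels)
    # Numerically stable log-softmax.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    ce = -log_probs[np.arange(len(labels)), labels].mean()
    preds = logits.argmax(axis=1)
    accuracy = (preds == labels).mean()
    f1_per_class = []
    for c in range(logits.shape[1]):
        tp = np.sum((preds == c) & (labels == c))
        fp = np.sum((preds == c) & (labels != c))
        fn = np.sum((preds != c) & (labels == c))
        denom = 2 * tp + fp + fn
        # F1 = 2TP / (2TP + FP + FN); defined as 0 for a class that never occurs.
        f1_per_class.append(2 * tp / denom if denom > 0 else 0.0)
    return float(ce), float(accuracy), float(np.mean(f1_per_class))


ce, acc, f1 = classification_metrics([[2.0, 0.0], [0.0, 2.0], [2.0, 0.0]], [0, 1, 1])
print(ce, acc, f1)  # accuracy and macro F1 are both 2/3 here
```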
- loss(tensors, inference_outputs, generative_outputs, recon_weight: Tunable[Union[float, int]], cf_weight: Tunable[Union[float, int]], beta: Tunable[Union[float, int]], clf_weight: Tunable[Union[float, int]], n_cf: Tunable[int], kl_weight: float = 1.0, ensemble_method_cf=True)
Compute the loss for the model.
- Parameters:
tensors (dict) – Dictionary containing the input tensors.
inference_outputs (dict) – Dictionary containing the outputs from the inference step.
generative_outputs (dict) – Dictionary containing the outputs from the generative step.
recon_weight (Tunable[Union[float, int]]) – Weight for the reconstruction loss of X.
cf_weight (Tunable[Union[float, int]]) – Weight for the reconstruction loss of X_cf.
beta (Tunable[Union[float, int]]) – Weight for the KL divergence of Zi.
clf_weight (Tunable[Union[float, int]]) – Weight for the Si classifier loss.
n_cf (Tunable[int]) – Number of X_cf reconstructions (X_cf = a random permutation of X).
kl_weight (float, optional) – Weight for the KL divergence, by default 1.0.
ensemble_method_cf (bool, optional) – Whether to use the new counterfactual method, by default True.
- Returns:
Dictionary containing the computed losses and metrics.
- Return type:
dict
- class celldisect.trainingplan.CellDISECTTrainingPlan(*args: Any, **kwargs: Any)
Bases: TrainingPlan
Train VAEs with an adversarial loss option to encourage latent space mixing.
- Parameters:
module (BaseModuleClass) – A module instance from class BaseModuleClass.
recon_weight (Tunable[Union[float, int]]) – Weight for the reconstruction loss of X.
cf_weight (Tunable[Union[float, int]]) – Weight for the reconstruction loss of X_cf.
beta (Tunable[Union[float, int]]) – Weight for the KL divergence of Zi.
clf_weight (Tunable[Union[float, int]]) – Weight for the Si classifier loss.
adv_clf_weight (Tunable[Union[float, int]]) – Weight for the adversarial classifier loss.
adv_period (Tunable[int]) – Adversarial training period.
n_cf (Tunable[int]) – Number of X_cf reconstructions (a random permutation of n VAEs and a random half-batch subset for each trial).
optimizer (Tunable[Literal["Adam", "AdamW", "Custom"]], optional) – One of “Adam” (Adam), “AdamW” (AdamW), or “Custom”, which requires a custom optimizer creator callable to be passed via optimizer_creator. Default is “Adam”.
optimizer_creator (Optional[TorchOptimizerCreator], optional) – A callable taking in parameters and returning an Optimizer. This allows using any PyTorch optimizer with custom hyperparameters. Default is None.
lr (Tunable[float], optional) – Learning rate used for optimization, when optimizer_creator is None. Default is 1e-3.
weight_decay (Tunable[float], optional) – Weight decay used in optimization, when optimizer_creator is None. Default is 1e-6.
n_steps_kl_warmup (Tunable[int], optional) – Number of training steps (minibatches) to scale weight on KL divergences from 0 to 1. Only activated when n_epochs_kl_warmup is set to None. Default is None.
n_epochs_kl_warmup (Tunable[int], optional) – Number of epochs to scale weight on KL divergences from 0 to 1. Overrides n_steps_kl_warmup when both are not None. Default is 400.
n_epochs_pretrain_ae (Tunable[int], optional) – Number of epochs to pretrain the autoencoder. Default is 0.
reduce_lr_on_plateau (Tunable[bool], optional) – Whether to monitor the validation loss and reduce the learning rate when the validation-set lr_scheduler_metric plateaus. Default is True.
lr_factor (Tunable[float], optional) – Factor to reduce learning rate. Default is 0.6.
lr_patience (Tunable[int], optional) – Number of epochs with no improvement after which learning rate will be reduced. Default is 30.
lr_threshold (Tunable[float], optional) – Threshold for measuring the new optimum. Default is 0.0.
lr_scheduler_metric (Literal["loss_validation"], optional) – Which metric to track for learning rate reduction. Default is “loss_validation”.
lr_min (float, optional) – Minimum learning rate allowed. Default is 0.
scale_adversarial_loss (Union[float, Literal["auto"]], optional) – Scaling factor on the adversarial components of the loss. With “auto”, the adversarial loss is scaled from 1 down to 0, inversely to the KL warmup. Default is “auto”.
ensemble_method_cf (bool, optional) – Whether to use the new counterfactual method. Default is True.
kappa_optimizer2 (bool, optional) – Whether to use the second kappa optimizer. Default is True.
**loss_kwargs – Keyword args to pass to the loss method of the module. kl_weight should not be passed here and is handled automatically.
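The interaction between the KL warmup and scale_adversarial_loss="auto" can be sketched as follows. Two parts come from the parameter descriptions above: the schedule is linear, and the epoch schedule takes precedence over the step schedule. The exact formula, including the min/max weights, is an assumption modeled on scvi-tools' training plans, and the function names are hypothetical.

```python
def kl_weight(epoch, step, n_epochs_kl_warmup=400, n_steps_kl_warmup=None,
              max_kl_weight=1.0, min_kl_weight=0.0):
    """Linear KL warmup; the epoch schedule takes precedence over the step schedule."""
    slope = max_kl_weight - min_kl_weight
    if n_epochs_kl_warmup:
        if epoch < n_epochs_kl_warmup:
            return min_kl_weight + slope * epoch / n_epochs_kl_warmup
    elif n_steps_kl_warmup:
        if step < n_steps_kl_warmup:
            return min_kl_weight + slope * step / n_steps_kl_warmup
    return max_kl_weight


def adversarial_scale(epoch, step, scale_adversarial_loss="auto", **warmup):
    # "auto" runs opposite to the KL warmup: 1 at the start, 0 once warmup ends.
    if scale_adversarial_loss == "auto":
        return 1.0 - kl_weight(epoch, step, **warmup)
    return float(scale_adversarial_loss)


for epoch in (0, 100, 200, 400):
    print(epoch, kl_weight(epoch, step=0), adversarial_scale(epoch, step=0))
```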
- initialize_train_metrics()
Initialize training-related metrics.
- initialize_val_metrics()
Initialize validation-related metrics.
- compute_and_log_metrics(loss_output: dict, metrics: Dict[str, scvi.train._metrics.ElboMetric], mode: str)
Computes and logs metrics.
This function updates the provided metrics dictionary with the values from the loss output and logs them using the appropriate logging method.
- adv_classifier_metrics(inference_outputs, detach_z=True)
Computes the loss for the adversarial classifier.
This function calculates the classification metrics for the adversarial classifier using the provided inference outputs.
- training_step(batch, batch_idx)
Training step for adversarial training.
- validation_step(batch, batch_idx)
Validation step.
- on_train_epoch_end()
Update the learning rate via scheduler steps.
- configure_optimizers()
Configure optimizers for adversarial training.