import logging
from typing import List, Literal, Optional, Sequence, Tuple, Union
import random
import numpy as np
import pandas as pd
import torch
from anndata import AnnData
import anndata as ad
import scanpy as sc
from scipy import sparse
from sklearn.utils.class_weight import compute_class_weight
from scvi import REGISTRY_KEYS
from scvi.data import AnnDataManager
from scvi.data.fields import (
CategoricalJointObsField,
CategoricalObsField,
LayerField,
NumericalJointObsField,
NumericalObsField,
)
from scvi.model.base import UnsupervisedTrainingMixin
from scvi.utils import setup_anndata_dsp
from scvi.dataloaders._data_splitting import DataSplitter
from scvi.dataloaders._ann_dataloader import AnnDataLoader
from scvi.train import TrainRunner
from scvi.model.base import RNASeqMixin, VAEMixin, BaseModelClass
from scvi.autotune._types import Tunable, TunableMixin
logger = logging.getLogger(__name__)
from ._module import CellDISECTModule, PerturbationEmbedding
from .data import AnnDataSplitter
from .trainingplan import CellDISECTTrainingPlan
from .utils import validate_perturbation_embeddings, parse_perturbation
from scvi.train._callbacks import SaveBestState
import torch.nn as nn
[docs]
class CellDISECT(
RNASeqMixin,
VAEMixin,
UnsupervisedTrainingMixin,
BaseModelClass,
TunableMixin
):
"""CellDISECT model for single-cell RNA sequencing data analysis.
Parameters
----------
adata
AnnData object that has been registered via :meth:`~scvi.model.SCVI.setup_anndata`.
n_hidden
Number of nodes per hidden layer.
n_latent_shared
Dimensionality of the shared latent space.
n_latent_attribute
Dimensionality of the latent space for each sensitive attribute.
n_layers
Number of hidden layers used for encoder and decoder neural networks.
dropout_rate
Dropout rate for neural networks.
gene_likelihood
One of:
* ``'nb'`` - Negative binomial distribution
* ``'zinb'`` - Zero-inflated negative binomial distribution
* ``'poisson'`` - Poisson distribution
latent_distribution
One of:
* ``'normal'`` - Normal distribution
* ``'ln'`` - Logistic normal distribution (Normal(0, I) transformed by softmax)
split_key
Key in `adata.obs` to split the data into training, validation, and test sets.
train_split
Values in `split_key` to be used for training.
valid_split
Values in `split_key` to be used for validation.
test_split
Values in `split_key` to be used for testing.
weighted_classifier
Whether to use weighted classifiers for categorical covariates.
**model_kwargs
Additional keyword arguments for the model.
"""
_module_cls = CellDISECTModule
_data_splitter_cls = AnnDataSplitter
_training_plan_cls = CellDISECTTrainingPlan
_train_runner_cls = TrainRunner
[docs]
def __init__(
    self,
    adata: AnnData,
    n_hidden: int = 128,
    n_latent_shared: int = 10,
    n_latent_attribute: int = 10,
    n_layers: int = 1,
    dropout_rate: float = 0.1,
    gene_likelihood: Literal["zinb", "nb", "poisson"] = "zinb",
    latent_distribution: Literal["normal", "ln"] = "normal",
    split_key: Optional[str] = None,
    # NOTE: the list defaults below are shared between calls; they are only
    # ever read (never mutated) inside this method.
    train_split: Union[str, List[str]] = ["train"],
    valid_split: Union[str, List[str]] = ["valid"],
    test_split: Union[str, List[str]] = ["ood"],
    weighted_classifier=False,
    use_bias: bool = True,
    **model_kwargs,
):
    """
    Initialize the CellDISECT model.

    Parameters
    ----------
    adata : AnnData
        AnnData object that has been registered via this class's
        ``setup_anndata``.
    n_hidden : int, optional
        Number of nodes per hidden layer, by default 128.
    n_latent_shared : int, optional
        Dimensionality of the shared latent space, by default 10.
    n_latent_attribute : int, optional
        Dimensionality of the latent space for each sensitive attribute, by default 10.
    n_layers : int, optional
        Number of hidden layers used for encoder and decoder neural networks, by default 1.
    dropout_rate : float, optional
        Dropout rate for neural networks, by default 0.1.
    gene_likelihood : Literal["zinb", "nb", "poisson"], optional
        Gene likelihood distribution, by default "zinb".
    latent_distribution : Literal["normal", "ln"], optional
        Latent distribution, by default "normal".
    split_key : str, optional
        Key in `adata.obs` to split the data into training, validation, and test sets, by default None.
        When None, no split indices are stored on the model.
    train_split : Union[str, List[str]], optional
        Value(s) in `split_key` to be used for training, by default ["train"].
        A single string is treated as a one-element list.
    valid_split : Union[str, List[str]], optional
        Value(s) in `split_key` to be used for validation, by default ["valid"].
        A single string is treated as a one-element list.
    test_split : Union[str, List[str]], optional
        Value(s) in `split_key` to be used for testing, by default ["ood"].
        A single string is treated as a one-element list.
    weighted_classifier : bool, optional
        Whether to use weighted classifiers for categorical covariates, by default False.
    use_bias : bool, optional
        Whether to use bias terms in the neural networks, by default True.
    **model_kwargs : dict
        Additional keyword arguments for the model. When a perturbation
        configuration is found in ``adata.uns``, the perturbation-related
        entries are added to (and override) this dictionary.

    Raises
    ------
    ValueError
        If `weighted_classifier` is True but no categorical covariates are registered.
    """
    super().__init__(adata)
    self._data_loader_cls = AnnDataLoader
    self.split_key = split_key

    has_cat_covs = REGISTRY_KEYS.CAT_COVS_KEY in self.adata_manager.data_registry
    n_cats_per_cov = (
        self.adata_manager.get_state_registry(
            REGISTRY_KEYS.CAT_COVS_KEY
        ).n_cats_per_key
        if has_cat_covs
        else None
    )

    self.classifier_weights = None
    if weighted_classifier:
        if not has_cat_covs:
            raise ValueError(
                "Cannot use weighted classifier without categorical covariates."
            )
        self.classifier_weights = []
        for covar in self.adata_manager.get_state_registry(
            REGISTRY_KEYS.CAT_COVS_KEY
        ).field_keys:
            y = self.adata.obs[covar].values
            # Weights are returned in the order of `classes`, i.e. the sorted
            # unique values of the column.
            # NOTE(review): presumably this matches the category order used by
            # the registry mappings -- confirm for pre-ordered categoricals.
            classes = np.unique(y)
            class_weight = compute_class_weight(class_weight="balanced", classes=classes, y=y)
            self.classifier_weights.append(class_weight)

    # Extract perturbation configuration if it was set during setup_anndata
    pert_config = adata.uns.get('_celldisect_perturbation_config', None)
    self._pert_config = pert_config
    if pert_config is not None:
        pert_cov_idx = pert_config['perturbation_cov_idx']
        pert_emb_key = pert_config['perturbation_embedding_key']
        predefined_embs = adata.uns[pert_emb_key]
        cat_covs_registry = self.adata_manager.get_state_registry(
            REGISTRY_KEYS.CAT_COVS_KEY
        )
        pert_category_names = list(
            cat_covs_registry.mappings[pert_config['perturbation_key']]
        )
        delimiter = pert_config['perturbation_combination_delimiter']
        model_kwargs['perturbation_cov_idx'] = pert_cov_idx
        model_kwargs['predefined_pert_embeddings'] = predefined_embs
        model_kwargs['perturbation_category_names'] = pert_category_names
        model_kwargs['perturbation_combination_delimiter'] = delimiter

    self.module = self._module_cls(
        n_input=self.summary_stats.n_vars,
        n_cats_per_cov=n_cats_per_cov,
        n_hidden=n_hidden,
        n_latent_shared=n_latent_shared,
        n_latent_attribute=n_latent_attribute,
        n_layers=n_layers,
        dropout_rate=dropout_rate,
        gene_likelihood=gene_likelihood,
        latent_distribution=latent_distribution,
        classifier_weights=self.classifier_weights,
        bias=use_bias,
        **model_kwargs,
    )

    if split_key is not None:
        def _as_list(split_values):
            # `Series.isin` only accepts list-like input, so a bare string
            # (allowed by the signature) has to be wrapped first.
            return [split_values] if isinstance(split_values, str) else split_values

        split_labels = adata.obs.loc[:, split_key]
        # Positional indices of the cells belonging to each split. The
        # constructor arguments themselves are left untouched so that the
        # recorded init params reflect what the caller passed.
        self.train_indices = np.where(split_labels.isin(_as_list(train_split)))[0]
        self.valid_indices = np.where(split_labels.isin(_as_list(valid_split)))[0]
        self.test_indices = np.where(split_labels.isin(_as_list(test_split)))[0]

    self._model_summary_string = (
        "CellDISECT Model with the following params: \nn_hidden: {}, n_latent_shared: {}, n_latent_attribute: {}"
        ", n_layers: {}, dropout_rate: {}, gene_likelihood: {}, latent_distribution: {}"
    ).format(
        n_hidden,
        n_latent_shared,
        n_latent_attribute,
        n_layers,
        dropout_rate,
        gene_likelihood,
        latent_distribution,
    )
    # `_get_init_params` receives every local name of this frame.
    self.init_params_ = self._get_init_params(locals())
@classmethod
@setup_anndata_dsp.dedent
def setup_anndata(
    cls,
    adata: AnnData,
    layer: Optional[str] = None,
    batch_key: Optional[str] = None,
    labels_key: Optional[str] = None,
    size_factor_key: Optional[str] = None,
    categorical_covariate_keys: Optional[List[str]] = None,
    continuous_covariate_keys: Optional[List[str]] = None,
    add_cluster_covariate: bool = False,
    clustering_normalize_counts: bool = True,
    perturbation_key: Optional[str] = None,
    perturbation_embedding_key: Optional[str] = None,
    perturbation_combination_delimiter: str = "+",
    **kwargs,
):
    """
    Set up the AnnData object for the CellDISECT model.

    This method configures the AnnData object by registering the necessary
    fields and optionally adding a cluster covariate. When
    ``perturbation_key`` is provided, the corresponding column in
    ``adata.obs`` is treated as a perturbation covariate whose embeddings
    come from ``adata.uns[perturbation_embedding_key]`` rather than being
    learned during training. The perturbation settings are stored in
    ``adata.uns['_celldisect_perturbation_config']``; when
    ``perturbation_key`` is None any such entry left over from an earlier
    call is removed so that a later model construction does not pick up
    stale settings.

    Parameters
    ----------
    adata : AnnData
        AnnData object to be set up.
    layer : Optional[str], optional
        Layer in `adata` to use as the count data, by default None.
    batch_key : Optional[str], optional
        Key in `adata.obs` for batch information, by default None.
    labels_key : Optional[str], optional
        Key in `adata.obs` for labels, by default None.
    size_factor_key : Optional[str], optional
        Key in `adata.obs` for size factors, by default None.
    categorical_covariate_keys : Optional[List[str]], optional
        List of keys in `adata.obs` for categorical covariates, by default None.
        The caller's list is never modified; ``perturbation_key`` is appended
        to a copy when it is not already present.
    continuous_covariate_keys : Optional[List[str]], optional
        List of keys in `adata.obs` for continuous covariates, by default None.
    add_cluster_covariate : bool, optional
        Whether to add a cluster covariate to `adata.obs`, by default False.
        When True, the ``'_cluster'`` column is also registered as a
        categorical field named ``'cluster'``.
    clustering_normalize_counts : bool, optional
        Whether to normalize counts before clustering, by default True.
    perturbation_key : Optional[str], optional
        Column in ``adata.obs`` that contains perturbation labels (e.g.
        ``"GeneA"``, ``"GeneA+GeneB"``). When set, the perturbation
        covariate uses predefined embeddings instead of learned ones.
    perturbation_embedding_key : Optional[str], optional
        Key in ``adata.uns`` whose value is a ``dict[str, np.ndarray]``
        mapping atomic perturbation names to their vector
        representations (e.g. ESM or GenePT embeddings). Required when
        ``perturbation_key`` is set.
    perturbation_combination_delimiter : str, optional
        Delimiter for combinatorial perturbation labels, by default ``"+"``.
    **kwargs
        Additional keyword arguments, forwarded to the field registration.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If ``perturbation_key`` is given without ``perturbation_embedding_key``.
    """
    # NOTE: every local name defined before the `locals()` call below is
    # recorded as a setup argument, so no helper variables are introduced
    # in this section.
    if perturbation_key is None:
        # Forget a perturbation configuration written by a previous call.
        adata.uns.pop('_celldisect_perturbation_config', None)
    else:
        if perturbation_embedding_key is None:
            raise ValueError(
                "When `perturbation_key` is set, `perturbation_embedding_key` "
                "must also be provided."
            )
        validate_perturbation_embeddings(
            adata, perturbation_key, perturbation_embedding_key,
            perturbation_combination_delimiter,
        )
        if categorical_covariate_keys is None:
            categorical_covariate_keys = []
        if perturbation_key not in categorical_covariate_keys:
            # Work on a copy so the caller's list is left untouched.
            categorical_covariate_keys = list(categorical_covariate_keys)
            categorical_covariate_keys.append(perturbation_key)
            logger.info(
                f"Auto-added perturbation_key '{perturbation_key}' to "
                f"categorical_covariate_keys."
            )
        adata.uns['_celldisect_perturbation_config'] = {
            'perturbation_key': perturbation_key,
            'perturbation_embedding_key': perturbation_embedding_key,
            'perturbation_combination_delimiter': perturbation_combination_delimiter,
            # Position of the perturbation covariate among the categorical covariates.
            'perturbation_cov_idx': categorical_covariate_keys.index(perturbation_key),
        }

    # Captured after the perturbation handling so that the recorded
    # `categorical_covariate_keys` includes the auto-added perturbation key.
    setup_method_args = cls._get_setup_method_args(**locals())

    if add_cluster_covariate:
        cls.add_cluster_covariate(
            adata,
            normalize_counts=clustering_normalize_counts
        )

    anndata_fields = [
        LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True),
        CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key),
        CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, labels_key),
        NumericalObsField(
            REGISTRY_KEYS.SIZE_FACTOR_KEY, size_factor_key, required=False
        ),
        CategoricalJointObsField(
            REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys
        ),
        NumericalJointObsField(
            REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys
        ),
    ]
    if add_cluster_covariate:
        anndata_fields.append(
            CategoricalObsField('cluster', '_cluster')
        )

    adata_manager = AnnDataManager(
        fields=anndata_fields, setup_method_args=setup_method_args
    )
    adata_manager.register_fields(adata, **kwargs)
    cls.register_manager(adata_manager)
[docs]
@classmethod
def add_cluster_covariate(
    cls,
    adata: AnnData,
    normalize_counts: bool = True):
    """
    Attach a Leiden cluster label to every cell as ``adata.obs['_cluster']``.

    The clusters are obtained by running PCA on ``adata.X``, building a
    neighbor graph on the PCA representation and running Leiden clustering
    on it. If ``adata.obs`` already has a ``'_cluster'`` column nothing is
    recomputed.

    Parameters
    ----------
    adata : AnnData
        AnnData object containing the single-cell RNA sequencing data. It is
        modified in place.
    normalize_counts : bool, optional
        If True, ``adata.X`` is overwritten with a copy of
        ``adata.layers['counts']`` which is then total-count normalized and
        log1p-transformed before clustering, by default True.

    Returns
    -------
    None
    """
    logger.info("Adding cluster covariate to adata.obs")

    # Guard clause: never overwrite an existing clustering.
    cluster_exists = '_cluster' in adata.obs
    if cluster_exists:
        logger.warning(
            "Cluster covariate already present in adata.obs, remove in case you want to re-run, skipping!")
        return None

    if normalize_counts:
        logger.info("Normalizing counts")
        raw_counts = adata.layers['counts']
        adata.X = raw_counts.copy()
        sc.pp.normalize_total(adata)  # scale each cell to the median total count
        sc.pp.log1p(adata)  # log-transform the normalized values

    logger.info("Running PCA and Leiden clustering")
    # Fixed random state at every step keeps the clustering reproducible.
    sc.tl.pca(adata, random_state=0)
    sc.pp.neighbors(adata, use_rep='X_pca', random_state=0)
    sc.tl.leiden(
        adata,
        key_added='_cluster',
        flavor='igraph',
        n_iterations=2,
        random_state=0,
    )
    return None
# call this method after training the model with this held-out:
# covs[cov_idx] = cov_value_cf, covs[others_idx] = adata.obs[others_idx]
@torch.no_grad()
def predict_given_covs_depricated(
    self,
    adata: AnnData,  # source anndata with fixed cov values
    cats: List[str],
    cov_idx: int,  # index in cats starting from 0
    cov_value_cf,
    batch_size: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Predict per-gene mean and variance after overriding one covariate.

    A copy of ``adata`` is made in which the covariate ``cats[cov_idx]`` is
    set to ``cov_value_cf`` for every cell. The copy is registered with
    ``layer='counts'`` and ``cats`` as categorical covariates, and passed
    through ``self.module.sub_forward`` with ``idx=cov_idx + 1``.

    Parameters
    ----------
    adata : AnnData
        Source data; it is copied, not modified.
    cats : List[str]
        Names of the categorical covariates.
    cov_idx : int
        Position (starting from 0) in ``cats`` of the covariate to override.
    cov_value_cf
        Counterfactual value assigned to that covariate for all cells.
    batch_size : Optional[int], optional
        Batch size handed to the data loader, by default None.

    Returns
    -------
    Tuple[torch.Tensor, torch.Tensor]
        The mean over cells of the predicted distribution means, and the
        mean squared deviation of the per-cell means from that average.
    """
    self._check_if_trained(warn=False)

    counterfactual = adata.copy()
    target_cov = cats[cov_idx]
    n_entries = len(counterfactual.obs[target_cov])
    counterfactual.obs[target_cov] = pd.Categorical([cov_value_cf] * n_entries)

    CellDISECT.setup_anndata(
        counterfactual,
        layer='counts',
        categorical_covariate_keys=cats,
        continuous_covariate_keys=[]
    )
    counterfactual = self._validate_anndata(counterfactual)
    loader = self._make_data_loader(adata=counterfactual, batch_size=batch_size)

    # One tensor of predicted means per mini-batch.
    batch_means = [
        self.module.sub_forward(
            idx=cov_idx + 1,
            x=batch[REGISTRY_KEYS.X_KEY].to(self.device),
            cat_covs=batch[REGISTRY_KEYS.CAT_COVS_KEY].to(self.device),
        ).mean
        for batch in loader
    ]
    per_cell_means = torch.cat(batch_means, dim=0)

    mean_pred = per_cell_means.mean(dim=0)
    squared_deviation = (per_cell_means - mean_pred) ** 2
    variance_pred = squared_deviation.mean(dim=0)
    return mean_pred, variance_pred
@torch.no_grad()
def predict_counterfactuals(
    self,
    adata: AnnData,
    cov_names: list[str],
    cov_values: list[str],
    cov_values_cf: list[str],
    cats: list[str],
    n_samples_from_source: Optional[int] = None,
    seed: Optional[int] = 0
):
    """Predicts counterfactuals for a given subset of data.

    This function estimates the counterfactual outcomes for a subset of data based on specified changes in covariate values.
    Cells whose covariates equal ``cov_values`` are the source (control) cells; their covariates are
    replaced by ``cov_values_cf`` and the model's counterfactual decoders are evaluated. Cells whose
    covariates already equal ``cov_values_cf`` are returned as the observed ("true") counterfactuals.

    The input ``adata`` is not modified; all work happens on copies of the selected subsets,
    whose ``X`` is set to a copy of their ``counts`` layer.

    Parameters
    ----------
    adata : AnnData
        The subset of the data for which the counterfactuals are to be predicted.
        Must contain the raw counts in ``adata.layers['counts']``.
    cov_names : list[str]
        Names of the covariates that are to be changed.
    cov_values : list[str]
        Original values for the covariates that are to be changed.
    cov_values_cf : list[str]
        Counterfactual values for the covariates that are to be changed.
    cats : list[str]
        Names of the categorical covariates.
    n_samples_from_source : Optional[int], optional
        Number of samples to take from the source data to predict the counterfactuals.
        If None, all samples from the source data are used. Defaults to None.
    seed : Optional[int], optional
        Random seed for reproducibility. Defaults to 0.
        Only used if `n_samples_from_source` is not None.

    Returns
    -------
    Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
        ``(x_ctrl, x_true, x_pred)``: control (source) counts, observed counterfactual counts, and the
        predicted counterfactual expression (the ``mu`` of each returned decoder distribution,
        averaged over decoders; not log-transformed). All tensors are on the CPU.

    Raises
    ------
    ValueError
        If ``cov_names``, ``cov_values`` and ``cov_values_cf`` differ in length, if no cell matches
        the source covariate values, or if ``n_samples_from_source`` exceeds the number of source cells.

    Examples
    --------
    Single covariate change::

        cats = ['cell_type', 'condition']
        cell_type_to_check = ['CD4 T',]
        cov_names = ['condition']
        cov_values = ['ctrl']
        cov_values_cf = ['stimulated']
        n_samples_from_source = 500
        x_ctrl, x_true, x_pred = model.predict_counterfactuals(
            adata[(adata.obs['cell_type'].isin(cell_type_to_check))].copy(),
            cov_names=cov_names,
            cov_values=cov_values,
            cov_values_cf=cov_values_cf,
            cats=cats,
            n_samples_from_source=n_samples_from_source,
        )

    Multiple covariate change::

        cats = ['tissue', 'Sample ID', 'sex', 'Age_bin', 'CoarseCellType']
        cell_type_to_check = 'Epithelial cell (luminal)'
        cov_names = ['sex', 'tissue']
        cov_values = ['female', 'breast']
        cov_values_cf = ['male', 'prostate gland']
        n_samples_from_source = None
        x_ctrl, x_true, x_pred = model.predict_counterfactuals(
            adata[adata.obs['Broad cell type'] == cell_type_to_check].copy(),
            cov_names=cov_names,
            cov_values=cov_values,
            cov_values_cf=cov_values_cf,
            cats=cats,
            n_samples_from_source=n_samples_from_source,
        )
    """
    if not (len(cov_names) == len(cov_values) == len(cov_values_cf)):
        raise ValueError(
            "`cov_names`, `cov_values` and `cov_values_cf` must have the same length."
        )

    def _matching_cells(values):
        # Boolean mask of the cells whose covariates all equal `values`.
        mask = np.ones(adata.n_obs, dtype=bool)
        for cov_name, value in zip(cov_names, values):
            mask &= np.asarray(adata.obs[cov_name] == value, dtype=bool)
        return mask

    def _dense_tensor(matrix):
        # Counts may be stored sparse; the returned tensors are dense.
        return torch.tensor(matrix.toarray() if sparse.issparse(matrix) else matrix)

    # Create true and source AnnData objects (copies, so the caller's
    # object is never written to).
    true_adata = adata[_matching_cells(cov_values_cf)].copy()
    source_adata = adata[_matching_cells(cov_values)].copy()
    for subset in (true_adata, source_adata):
        # Make the main matrix hold raw counts, whatever `adata.X` contained.
        subset.X = subset.layers['counts'].copy()

    if source_adata.n_obs == 0:
        raise ValueError(
            "No cells match the source covariate values given in `cov_values`."
        )

    # Sample from source data if specified
    if n_samples_from_source is not None:
        random.seed(seed)
        chosen_ids = random.sample(range(len(source_adata)), n_samples_from_source)
        source_adata = source_adata[chosen_ids].copy()

    # Update covariate values in the counterfactual data. The column is
    # replaced outright so the new value need not be an existing category.
    adata_cf = source_adata.copy()
    for cov_name, cf_value in zip(cov_names, cov_values_cf):
        adata_cf.obs[cov_name] = pd.Categorical([cf_value] * adata_cf.n_obs)

    # A single batch holding every cell keeps source and counterfactual
    # loaders aligned when they are zipped below.
    batch_size = len(adata_cf)
    # Use the device the module actually lives on.
    device = self.device

    # Setup AnnData for the counterfactual data
    self.setup_anndata(
        adata_cf,
        layer='counts',
        categorical_covariate_keys=cats,
        continuous_covariate_keys=[]
    )
    adata_cf = self._validate_anndata(adata_cf)
    source_adata = self._validate_anndata(source_adata)

    # Create data loaders for source and counterfactual data
    scdl_cf = self._make_data_loader(
        adata=adata_cf, batch_size=batch_size
    )
    scdl = self._make_data_loader(
        adata=source_adata, batch_size=batch_size
    )

    # Predict counterfactuals: collect the mean of every decoder that
    # returned a distribution.
    decoder_means = []
    for tensors, tensors_cf in zip(scdl, scdl_cf):
        _, pxs_cf = self.module.sub_forward_cf_avg(
            x=tensors[REGISTRY_KEYS.X_KEY].to(device),
            cat_covs=tensors[REGISTRY_KEYS.CAT_COVS_KEY].to(device),
            cat_covs_cf=tensors_cf[REGISTRY_KEYS.CAT_COVS_KEY].to(device)
        )
        decoder_means.extend(px_cf.mu for px_cf in pxs_cf if px_cf is not None)

    # Average over decoders -> (n_cells, n_genes), returned on the CPU.
    x_pred = torch.stack(decoder_means, dim=0).mean(dim=0).detach().cpu()

    # Get true and control counts
    x_true = _dense_tensor(true_adata.X)
    x_ctrl = _dense_tensor(source_adata.X)
    return x_ctrl, x_true, x_pred
@torch.no_grad()
def get_latent_representation(
    self,
    adata: Optional[AnnData] = None,
    indices: Optional[Sequence[int]] = None,
    batch_size: Optional[int] = None,
    nullify_cat_covs_indices: Optional[List[int]] = None,
    nullify_shared: Optional[bool] = False
) -> np.ndarray:
    """
    Get the latent representation of the data.

    This function computes the latent representation of the data using the trained model.
    It allows for optional nullification of specific categorical covariates or the shared latent space;
    both options are forwarded unchanged to ``self.module.inference``.

    Parameters
    ----------
    adata : Optional[AnnData]
        Annotated data object. If None, uses the data registered with the model.
    indices : Optional[Sequence[int]]
        Optional indices to subset the data.
    batch_size : Optional[int]
        Batch size to use for data loading. If None, uses the default batch size.
    nullify_cat_covs_indices : Optional[List[int]]
        List of indices of categorical covariates to nullify in the latent space. If None, no covariates are nullified.
    nullify_shared : Optional[bool]
        If True, nullifies the shared latent space. Defaults to False.

    Returns
    -------
    np.ndarray
        The ``z_concat`` inference output of every mini-batch, concatenated
        along the first axis in data-loader order. A single array is returned
        regardless of the nullification arguments.
    """
    self._check_if_trained(warn=False)
    adata = self._validate_anndata(adata)
    scdl = self._make_data_loader(
        adata=adata, indices=indices, batch_size=batch_size
    )

    latent = []
    for tensors in scdl:
        inference_inputs = self.module._get_inference_input(tensors)
        outputs = self.module.inference(
            **inference_inputs,
            nullify_cat_covs_indices=nullify_cat_covs_indices,
            nullify_shared=nullify_shared,
        )
        # Move each batch to the CPU right away so device memory is released.
        latent.append(outputs["z_concat"].cpu())
    return torch.cat(latent).numpy()
@torch.no_grad()
def get_cat_covariate_latents(
    self,
):
    """
    Returns the embeddings of the categorical covariates.

    This function retrieves the embeddings of the categorical covariates from the trained model.
    It also returns the mappings of the categorical covariates.

    Returns
    -------
    Tuple[dict, dict]
        A tuple containing two dictionaries:

        - covar_embeddings: Dictionary where keys are covariate names and values are the
          embedding weights as numpy arrays.
        - covar_mappings: Dictionary where keys are covariate names and values are the
          registered category mappings of the covariates.
    """
    self._check_if_trained(warn=False)

    cat_covs_registry = self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY)
    # Covariate names in registration order. (Subscripting `.values()` does
    # not work on a dict view, so the names are read from `field_keys`.)
    covar_names = cat_covs_registry.field_keys
    mappings = cat_covs_registry.mappings

    covar_embeddings = {}
    covar_mappings = {}
    # Embedding modules are paired with covariate names by position.
    # NOTE(review): assumes every embedding module exposes `.weight` --
    # confirm for the predefined perturbation embedding.
    for name, emb in zip(covar_names, self.module.covars_embeddings.values()):
        covar_embeddings[name] = emb.weight.cpu().detach().numpy()
        covar_mappings[name] = mappings[name]
    return covar_embeddings, covar_mappings
@torch.no_grad()
def get_gene_importance(self) -> pd.DataFrame:
    """
    Score how strongly each gene feeds into every encoder.

    For each encoder in ``self.module.z_encoders_list`` the weight matrices
    of its linear layers (followed by its mean-encoder layer) are chained
    into a single genes-by-latent matrix: the first layer's weights are
    restricted to the gene inputs, and before every further layer the
    accumulated matrix is passed through ReLU and multiplied by that
    layer's transposed weights. Bias terms are not used. A gene's score is
    the sum of absolute values of its row, and the scores of each encoder
    are divided by their total so that they sum to one.

    Returns
    -------
    pd.DataFrame
        Gene importance scores, indexed by gene name, with one column per
        encoder (``'shared'``, ``'attribute_0'``, ``'attribute_1'``, ...).
    """
    self._check_if_trained(warn=False)
    self.module.eval()

    num_genes = self.summary_stats.n_vars
    importances_df = pd.DataFrame(index=self.adata_manager.adata.var_names.to_list())
    encoder_names = ["shared", *(f"attribute_{k}" for k in range(self.module.zs_num))]

    for position, encoder in enumerate(self.module.z_encoders_list):
        # All linear layers of the hidden stack, then the mean encoder layer.
        linear_layers = [
            layer
            for block in list(encoder.encoder.fc_layers)
            for layer in block
            if isinstance(layer, nn.Linear)
        ]
        linear_layers.append(encoder.mean_encoder)

        first_layer, *later_layers = linear_layers
        # Linear weights are (out_features, in_features); transposing gives
        # (in_features, out_features). The encoder input is genes followed by
        # covariates, so only the leading gene rows are kept.
        effective = first_layer.weight.t()[:num_genes, :]
        for layer in later_layers:
            effective = torch.relu(effective) @ layer.weight.t()

        scores = torch.abs(effective).sum(dim=1)
        normalized = scores / scores.sum()
        importances_df[encoder_names[position]] = normalized.cpu().numpy()

    return importances_df
[docs]
def train(
    self,
    max_epochs: Optional[int] = None,
    use_gpu: Optional[Union[str, int, bool]] = True,
    train_size: float = 0.8,
    validation_size: Optional[float] = None,
    batch_size: int = 256,
    early_stopping: bool = True,
    save_best: bool = False,
    plan_kwargs: Optional[dict] = None,
    recon_weight: Tunable[Union[float, int]] = 10,  # RECONST_LOSS_X weight
    cf_weight: Tunable[Union[float, int]] = 1,  # RECONST_LOSS_X_CF weight
    beta: Tunable[Union[float, int]] = 1,  # KL Zi weight
    clf_weight: Tunable[Union[float, int]] = 50,  # Si classifier weight
    adv_clf_weight: Tunable[Union[float, int]] = 10,  # adversarial classifier weight
    adv_period: Tunable[int] = 1,  # adversarial training period
    n_cf: Tunable[int] = 10,  # number of X_cf recons (a random permutation of n VAEs and a random half-batch subset for each trial)
    kappa_optimizer2: bool = True,
    n_epochs_pretrain_ae: int = 0,
    **trainer_kwargs,
):
    """
    Fit the model.

    Builds a data splitter, a training plan and a train runner, then runs
    the runner. When the model was created with a ``split_key`` the
    train/validation/test indices stored at construction time are used;
    otherwise the data are split by ``train_size`` / ``validation_size``.

    Parameters
    ----------
    max_epochs : Optional[int]
        Number of passes through the dataset. If `None`, defaults to `np.min([round((20000 / n_cells) * 400), 400])`.
    use_gpu : Optional[Union[str, int, bool]]
        GPU selection, passed to the train runner and, when a `split_key`
        is in use, to the data splitter.
    train_size : float
        Size of the training set in the range [0.0, 1.0]. Only used when the
        model has no `split_key`.
    validation_size : Optional[float]
        Size of the validation set, passed to the data splitter together
        with `train_size`. Only used when the model has no `split_key`.
    batch_size : int
        Minibatch size to use during training.
    early_stopping : bool
        Perform early stopping. An `early_stopping` entry given in
        `**trainer_kwargs` takes precedence over this argument.
    save_best : bool
        If True, a `SaveBestState` callback monitoring `loss_validation` is
        appended to the trainer callbacks.
    plan_kwargs : Optional[dict]
        Extra keyword arguments for the training plan. Anything that is not
        a dict is replaced by an empty dict.
    recon_weight : Tunable[Union[float, int]]
        Weight for the reconstruction loss of X.
    cf_weight : Tunable[Union[float, int]]
        Weight for the reconstruction loss of X_cf.
    beta : Tunable[Union[float, int]]
        Weight for the KL divergence of Zi.
    clf_weight : Tunable[Union[float, int]]
        Weight for the Si classifier loss.
    adv_clf_weight : Tunable[Union[float, int]]
        Weight for the adversarial classifier loss.
    adv_period : Tunable[int]
        Adversarial training period.
    n_cf : Tunable[int]
        Number of X_cf reconstructions (a random permutation of n VAEs and a random half-batch subset for each trial).
    kappa_optimizer2 : bool
        Whether to use the second kappa optimizer.
    n_epochs_pretrain_ae : int
        Number of epochs to pretrain the autoencoder.
    **trainer_kwargs
        Other keyword arguments for the train runner. `enable_checkpointing`
        and `early_stopping_monitor` are always overwritten (with True and
        "loss_validation" respectively).

    Returns
    -------
    The value returned by calling the train runner.
    """
    n_obs = self.adata.n_obs
    if max_epochs is None:
        # Heuristic: fewer epochs for larger data sets, capped at 400.
        max_epochs = int(np.min([round((20000 / n_obs) * 400), 400]))

    if not isinstance(plan_kwargs, dict):
        plan_kwargs = {}

    if self.split_key is None:
        # No predefined split: partition the cells by the requested fractions.
        data_splitter = DataSplitter(
            adata_manager=self.adata_manager,
            train_size=train_size,
            validation_size=validation_size,
            batch_size=batch_size,
        )
    else:
        # Reuse the indices derived from `split_key` when the model was built.
        data_splitter = AnnDataSplitter(
            self.adata_manager,
            train_indices=self.train_indices,
            valid_indices=self.valid_indices,
            test_indices=self.test_indices,
            batch_size=batch_size,
            use_gpu=use_gpu,
            drop_last=3,
        )

    loss_settings = {
        "recon_weight": recon_weight,
        "cf_weight": cf_weight,
        "beta": beta,
        "clf_weight": clf_weight,
        "adv_clf_weight": adv_clf_weight,
        "adv_period": adv_period,
        "n_cf": n_cf,
        "kappa_optimizer2": kappa_optimizer2,
        "n_epochs_pretrain_ae": n_epochs_pretrain_ae,
    }
    training_plan = self._training_plan_cls(
        self.module, **loss_settings, **plan_kwargs
    )

    # An explicit trainer kwarg wins over the `early_stopping` argument.
    trainer_kwargs.setdefault("early_stopping", early_stopping)

    if save_best:
        best_state_callback = SaveBestState(
            monitor="loss_validation", mode="min", period=1, verbose=True
        )
        trainer_kwargs.setdefault("callbacks", [])
        trainer_kwargs["callbacks"].append(best_state_callback)

    trainer_kwargs['enable_checkpointing'] = True
    trainer_kwargs['early_stopping_monitor'] = "loss_validation"

    runner = self._train_runner_cls(
        self,
        training_plan=training_plan,
        data_splitter=data_splitter,
        max_epochs=max_epochs,
        use_gpu=use_gpu,
        **trainer_kwargs,
    )
    return runner()
@torch.no_grad()
def generate(
self,
n_samples: int,
covariates: dict[str, str],
library_size: Optional[float] = None,
adata_for_library_size: Optional[AnnData] = None,
random_seed: Optional[int] = None,
include_shared_latent: bool = False,
):
"""
Generate new cells conditionally based on specified categorical covariates.
This method generates `n_samples` new cells for a given combination of categorical covariates.
The generation process involves sampling from the latent spaces and then decoding to
get gene expression profiles.
The shared latent space (`z_shared`) is sampled from a standard normal distribution.
The attribute-specific latent spaces (`zs`) are sampled from their respective prior encoders,
conditioned on the provided covariates.
Parameters
----------
n_samples
Number of cells to generate.
covariates
A dictionary mapping categorical covariate names to the desired values for generation.
All categorical covariates used to train the model must be specified.
library_size
A specific library size to use for generation.
Mutually exclusive with `adata_for_library_size`.
adata_for_library_size
An AnnData object to infer the median library size from. If cells with the specified
`covariates` combination exist in this AnnData object, their median library size is used.
Otherwise, the median library size of the entire `adata_for_library_size` is used.
Mutually exclusive with `library_size`.
random_seed
A random seed for reproducibility.
include_shared_latent
Whether to include the expression from the decoder associated with the shared latent space (Z_0)
in the `average_expression`. If `False`, the average is computed over attribute-specific
decoders only. We suggest setting this to `False` for most use cases. The shared decoder
is assumed to be `decoder_0`.
Returns
-------
A tuple containing:
- **average_expression** (numpy.ndarray): An array of shape `(n_samples, n_vars)`
representing the average gene expression across the selected decoders.
- **all_expressions** (dict[str, numpy.ndarray]): A dictionary where keys are
decoder names (`decoder_0`, `decoder_1`, etc.) and values are numpy arrays of
shape `(n_samples, n_vars)` representing the generated gene expression counts
from each individual decoder.
Examples
--------
>>> n_generated_cells = 10
>>> desired_covariates = {"cell_type": "T-cell", "condition": "diseased"}
>>> avg_expr, all_expr = model.generate(
... n_samples=n_generated_cells,
... covariates=desired_covariates,
... adata_for_library_size=adata
... )
To compute a custom average using only a subset of decoders (e.g., decoder_0 and decoder_2):
>>> import numpy as np
>>> custom_avg = np.mean(
... [all_expr["decoder_0"], all_expr["decoder_2"]], axis=0
... )
"""
if library_size is None and adata_for_library_size is None:
raise ValueError("Either library_size or adata_for_library_size must be provided.")
if adata_for_library_size is not None:
adata_for_library_size = adata_for_library_size.copy()
self._check_if_trained(warn=False)
self.module.eval()
if random_seed is not None:
random.seed(random_seed)
torch.manual_seed(random_seed)
np.random.seed(random_seed)
cat_covs_registry = self.adata_manager.get_state_registry(
REGISTRY_KEYS.CAT_COVS_KEY
)
cov_names = cat_covs_registry.field_keys
mappings = cat_covs_registry.mappings
if set(covariates.keys()) != set(cov_names):
raise ValueError(
"Please provide a value for every categorical covariate. "
f"Required: {cov_names}, but got: {list(covariates.keys())}"
)
cat_indices = []
for cov_name in cov_names:
cov_value = covariates[cov_name]
mapping = mappings[cov_name]
if cov_value not in mapping:
raise ValueError(
f"Covariate value '{cov_value}' not found for covariate '{cov_name}'. "
f"Available values are: {list(mapping)}"
)
value_idx = np.where(mapping == cov_value)[0][0]
cat_indices.append(value_idx)
cat_covs_tensor = (
torch.tensor(cat_indices, device=self.device).unsqueeze(0).expand(n_samples, -1)
)
# 1. Sample z_shared from a standard Gaussian
z_shared = torch.randn(
n_samples, self.module.n_latent_shared, device=self.device
)
# 2. Sample zs from prior encoders
prior_emb_in = []
for i, embedding_indices in enumerate(cat_covs_tensor.t()):
emb = self.module.covars_embeddings[str(i)](embedding_indices.long())
emb = emb + torch.randn_like(emb) * 0.1
prior_emb_in.append(emb)
zs = []
for i in range(len(self.module.z_prior_encoders_list)):
# Encoder returns (distribution, sample)
_, z_s_i = self.module.z_prior_encoders_list[i](prior_emb_in[i])
zs.append(z_s_i)
# 3. Decode
if library_size is not None:
library = torch.full(
(n_samples, 1), library_size
).to(self.device)
elif adata_for_library_size is not None:
# If there are any cells for the covariate combination, use the library size of the subset, if not, use the library size of the entire dataset
cov_subset_mask = []
for cov_name, cov_value in covariates.items():
cov_subset_mask.append(adata_for_library_size.obs[cov_name] == cov_value)
cov_subset_mask = np.all(cov_subset_mask, axis=0)
if np.any(cov_subset_mask):
print(f"Using library size of the subset for the covariate combination {covariates}.")
print(f"Number of cells in the subset: {np.sum(cov_subset_mask)}")
adata_for_library_size = adata_for_library_size[cov_subset_mask]
else:
print(f"No cells found for the covariate combination {covariates} in the adata_for_library_size dataset.")
print(f"Using library size of the entire dataset.")
library_global = torch.log(torch.tensor(adata_for_library_size.X.sum(1)).unsqueeze(1))
library_median = library_global.median().item()
print(f"Library size median: {library_median}")
library = torch.full(
(n_samples, 1), library_median
).to(self.device)
else:
raise ValueError("Either library_size or adata_for_library_size must be provided.")
generative_kwargs = {
"z_shared": z_shared,
"zs": zs,
"library": library,
"cat_covs": cat_covs_tensor,
}
generative_outputs = self.module.generative(**generative_kwargs)
pxs = generative_outputs["px"]
# 4. Return generated expressions from all decoders
all_expressions = {}
for i, px in enumerate(pxs):
all_expressions[f"decoder_{i}"] = px.mu.cpu().detach().numpy()
# 5. Compute average expression
if include_shared_latent:
average_expression = np.mean(list(all_expressions.values()), axis=0)
else:
average_expression = np.mean(list(all_expressions.values())[1:], axis=0)
return average_expression, all_expressions
# ------------------------------------------------------------------
# Perturbation embedding management API
# ------------------------------------------------------------------
[docs]
def add_perturbation_embedding(
    self,
    name: str,
    embedding: np.ndarray,
) -> None:
    """Attach an embedding vector for a not-yet-known atomic perturbation.

    The vector is cast to ``float32`` and handed to the ``add_perturbation``
    method of the perturbation covariate's embedding module. Registering it
    is what allows later counterfactual predictions to reference ``name``
    (alone or as part of a combinatorial perturbation).

    Parameters
    ----------
    name
        Atomic perturbation name (e.g. ``"GeneC"``).
    embedding
        Vector representation of the perturbation. It is expected to have the
        same dimensionality as the predefined embeddings used during training;
        that is not checked in this method.

    Raises
    ------
    RuntimeError
        If the model has no perturbation configuration, or if the embedding
        module registered for the perturbation covariate is not a
        ``PerturbationEmbedding``.
    """
    config = self._pert_config
    if config is None:
        raise RuntimeError(
            "This model was not set up with perturbation support. "
            "Use perturbation_key in setup_anndata first."
        )

    # Embedding modules are keyed by the stringified covariate position.
    slot = str(config['perturbation_cov_idx'])
    emb_layer = self.module.covars_embeddings[slot]

    if isinstance(emb_layer, PerturbationEmbedding):
        emb_layer.add_perturbation(name, np.asarray(embedding, dtype=np.float32))
        return

    raise RuntimeError("Perturbation embedding module not found.")
[docs]
def get_perturbation_embeddings(self) -> Tuple[List[str], np.ndarray]:
    """Return current perturbation names and their (combined) embeddings.

    Returns
    -------
    names
        Ordered list of perturbation category names, copied from the
        perturbation embedding module's ``_category_names``.
    embeddings
        Numpy array of shape ``(n_perturbations, emb_dim)`` holding the
        embedding weight matrix. The array is an independent copy: modifying
        it does not change the model's parameters.

    Raises
    ------
    RuntimeError
        If the model was not set up with perturbation support.
    """
    if self._pert_config is None:
        raise RuntimeError(
            "This model was not set up with perturbation support."
        )
    pert_idx = self._pert_config['perturbation_cov_idx']
    pert_emb_module = self.module.covars_embeddings[str(pert_idx)]
    names = list(pert_emb_module._category_names)
    # ``Tensor.numpy()`` shares storage with a CPU tensor and ``.cpu()`` is a
    # no-op for a tensor that already lives on the CPU, so without the explicit
    # copy a caller editing the returned array would overwrite the live weights.
    weights = pert_emb_module.embedding.weight.detach().cpu().numpy().copy()
    return names, weights
# ------------------------------------------------------------------
# Perturbation prediction API
# ------------------------------------------------------------------
@torch.no_grad()
def predict_perturbation(
    self,
    adata: AnnData,
    perturbation: str,
    source_perturbation: str,
    cats: List[str],
    perturbation_key: str,
    new_embeddings: Optional[dict] = None,
    n_samples_from_source: Optional[int] = None,
    seed: Optional[int] = 0,
    device: Optional[Union[str, torch.device]] = None,
    batch_size: Optional[int] = 256,
    source_adata: Optional[AnnData] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
    """Predict gene expression after a perturbation.

    Source cells are the rows whose ``perturbation_key`` column equals
    ``source_perturbation``. Their perturbation covariate is replaced by
    ``perturbation`` and the counterfactual means (``px.mu``) returned by
    ``module.sub_forward_cf_avg`` are averaged over all decoders that produced
    an output. Only the selected rows of the ``'counts'`` layer are densified;
    the AnnData objects themselves are never copied.

    While predicting, the perturbation embedding module is rebuilt for a
    mapping extended with the requested labels and the module may be moved to
    ``device``. Both changes are undone before returning, also on failure.
    Embeddings passed through ``new_embeddings`` are the exception: they stay
    registered on the embedding module after the call.

    Parameters
    ----------
    adata
        AnnData object containing cells. Must have a ``'counts'`` layer.
        Ground-truth cells are always looked up here.
    perturbation
        Target perturbation label (e.g. ``"GeneA"`` or ``"GeneA+GeneB"``).
    source_perturbation
        Source / control perturbation label (e.g. ``"ctrl"``).
    cats
        List of categorical covariate keys (same order used during training).
    perturbation_key
        Which element of ``cats`` is the perturbation covariate.
    new_embeddings
        Optional dictionary mapping atomic perturbation names to embedding
        vectors. Names that already have an embedding are left untouched.
    n_samples_from_source
        If set and smaller than the number of source cells, randomly sample
        this many cells from the source group. Must be positive.
    seed
        Random seed for the source-cell sampling. A private generator is
        used, so the global ``random`` state is not modified.
    device
        Device to run prediction on. If ``None``, uses the model's device.
    batch_size
        Batch size for forward passes. Default 256. ``None`` processes all
        source cells in a single batch. Must be positive.
    source_adata
        Optional AnnData with control cells. When given, source cells and
        their covariates are read from it instead of ``adata``.

    Returns
    -------
    x_ctrl
        Source (control) counts, shape ``(n_source, n_genes)``.
    x_true
        Target (ground-truth) counts if cells exist, else ``None``.
    x_pred
        Predicted counts, shape ``(n_source, n_genes)``.

    Raises
    ------
    RuntimeError
        If the model has no perturbation configuration, or if no decoder
        returned a prediction for a batch.
    ValueError
        If ``batch_size`` or ``n_samples_from_source`` is not positive, if an
        atomic component of ``perturbation`` / ``source_perturbation`` has no
        embedding, if no source cells are found, or if a source cell carries a
        category value unknown to the trained model.
    """
    self._check_if_trained(warn=False)
    if self._pert_config is None:
        raise RuntimeError(
            "This model was not set up with perturbation support. "
            "Use perturbation_key in setup_anndata first."
        )
    # Reject sizes that would otherwise fail late with an obscure error
    # (zero range step, empty torch.cat, float index array).
    if batch_size is not None and batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer or None, got {batch_size}.")
    if n_samples_from_source is not None and n_samples_from_source <= 0:
        raise ValueError(
            f"n_samples_from_source must be a positive integer or None, got {n_samples_from_source}."
        )

    pert_idx = self._pert_config['perturbation_cov_idx']
    pert_emb_module = self.module.covars_embeddings[str(pert_idx)]
    delimiter = self._pert_config['perturbation_combination_delimiter']

    # Register new embeddings if provided (existing names are never overwritten)
    if new_embeddings is not None:
        for name, vec in new_embeddings.items():
            if name not in pert_emb_module._predefined_embeddings:
                pert_emb_module._predefined_embeddings[name] = np.asarray(vec, dtype=np.float32)

    # Validate all perturbation components have embeddings
    for comp in parse_perturbation(perturbation, delimiter):
        if comp not in pert_emb_module._predefined_embeddings:
            raise ValueError(f"Atomic perturbation '{comp}' has no embedding.")
    for comp in parse_perturbation(source_perturbation, delimiter):
        if comp not in pert_emb_module._predefined_embeddings:
            raise ValueError(f"Source perturbation component '{comp}' has no embedding.")

    # Determine device
    _device = torch.device(device) if isinstance(device, str) else (device or self.device)

    # Get source cell indices (NO COPY of adata)
    data_source = source_adata if source_adata is not None else adata
    source_mask = data_source.obs[perturbation_key] == source_perturbation
    source_indices = np.where(source_mask)[0]
    if len(source_indices) == 0:
        raise ValueError(f"No cells with {perturbation_key}='{source_perturbation}' found.")

    # Sample if requested. A dedicated generator yields the same draw as
    # seeding the module-level one, without clobbering the caller's RNG state.
    if n_samples_from_source is not None and n_samples_from_source < len(source_indices):
        rng = random.Random(seed)
        source_indices = np.array(rng.sample(list(source_indices), n_samples_from_source))

    # Extract source counts directly (minimal memory - just the subset we need)
    X_source = data_source.layers['counts'][source_indices]
    if sparse.issparse(X_source):
        X_source = X_source.toarray()
    x_ctrl = torch.tensor(X_source, dtype=torch.float32)

    # Encode the source covariates with the category mappings stored by the
    # TRAINED model's manager, so indices match what the module was fit on.
    source_cat_covs = []
    trained_cat_registry = self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY)
    for cat_key in cats:
        cat_values = data_source.obs[cat_key].iloc[source_indices].values
        mapping = trained_cat_registry.mappings[cat_key]
        # Resolve each distinct value once instead of scanning the mapping per cell.
        index_of = {}
        cat_indices = []
        for val in cat_values:
            if val not in index_of:
                hits = np.where(mapping == val)[0]
                if len(hits) == 0:
                    raise ValueError(f"Category value '{val}' not found in training mapping for '{cat_key}'")
                index_of[val] = hits[0]
            cat_indices.append(index_of[val])
        source_cat_covs.append(torch.tensor(cat_indices, dtype=torch.long))
    cat_covs_tensor = torch.stack(source_cat_covs, dim=1)  # (n_cells, n_cats)

    # Build new perturbation mapping that includes the target perturbation
    old_category_names = list(pert_emb_module._category_names)
    new_mapping = list(old_category_names)
    if perturbation not in new_mapping:
        new_mapping.append(perturbation)
    if source_perturbation not in new_mapping:
        new_mapping.append(source_perturbation)

    # Get indices in new mapping
    source_pert_idx = new_mapping.index(source_perturbation)
    target_pert_idx = new_mapping.index(perturbation)

    # Update source cat_covs to use source perturbation index.
    # Column ``pert_idx`` is the perturbation covariate only if ``cats`` follows
    # the training order, as the parameter documentation requires.
    cat_covs_tensor[:, pert_idx] = source_pert_idx

    n_cells = len(source_indices)
    _batch_size = batch_size if batch_size is not None else n_cells
    px_cf_mean_list = []
    module_device = None
    # Everything that temporarily alters the model lives inside the ``try`` so
    # the ``finally`` clause restores it even if the rebuild or device move fails.
    try:
        # Rebuild embedding for new mapping
        pert_emb_module.rebuild_for_mapping(new_mapping)

        # Move model to device if needed
        module_device = next(self.module.parameters()).device
        if module_device != _device:
            self.module.to(_device)

        # Run prediction in batches
        for start in range(0, n_cells, _batch_size):
            end = min(start + _batch_size, n_cells)
            x_batch = x_ctrl[start:end].to(_device)
            cat_covs_batch = cat_covs_tensor[start:end].to(_device)

            # Create counterfactual covariates (change perturbation)
            cat_covs_cf = cat_covs_batch.clone()
            cat_covs_cf[:, pert_idx] = target_pert_idx

            # Forward pass
            _, pxs_cf = self.module.sub_forward_cf_avg(
                x=x_batch,
                cat_covs=cat_covs_batch,
                cat_covs_cf=cat_covs_cf,
            )

            # Average predictions from all decoders
            batch_preds = [px_cf.mu for px_cf in pxs_cf if px_cf is not None]
            if not batch_preds:
                # Skipping the batch would leave x_pred with fewer rows than x_ctrl.
                raise RuntimeError(
                    f"No decoder returned a prediction for source cells {start}:{end}."
                )
            batch_mean = torch.stack(batch_preds, dim=0).mean(dim=0)
            px_cf_mean_list.append(batch_mean.cpu())

            # Clear GPU memory
            del x_batch, cat_covs_batch, cat_covs_cf
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        x_pred = torch.cat(px_cf_mean_list, dim=0)
    finally:
        # Restore original mapping
        pert_emb_module.rebuild_for_mapping(old_category_names)
        if module_device is not None and module_device != _device:
            self.module.to(module_device)

    # Get ground truth (on-demand, no copy of full adata)
    x_true = None
    # Plain NumPy boolean mask for indexing the counts matrix.
    target_mask = (adata.obs[perturbation_key] == perturbation).to_numpy()
    if target_mask.any():
        X_true = adata.layers['counts'][target_mask]
        if sparse.issparse(X_true):
            X_true = X_true.toarray()
        x_true = torch.tensor(X_true, dtype=torch.float32)

    return x_ctrl, x_true, x_pred
@torch.no_grad()
def predict_perturbations(
    self,
    adata: AnnData,
    perturbations: List[str],
    source_perturbation: str,
    cats: List[str],
    perturbation_key: str,
    new_embeddings: Optional[dict] = None,
    n_samples_from_source: Optional[int] = None,
    seed: Optional[int] = 0,
    device: Optional[Union[str, torch.device]] = None,
    batch_size: Optional[int] = 256,
    source_adata: Optional[AnnData] = None,
) -> dict:
    """Run :meth:`predict_perturbation` once per requested target label.

    Every target is predicted with the same source group and settings; the
    per-target results are collected in a dictionary in the order the
    targets were given.

    Parameters
    ----------
    adata
        AnnData object containing cells. Must have a ``'counts'`` layer.
    perturbations
        List of target perturbation labels. An empty (or falsy) value yields
        an empty dictionary without running any prediction.
    source_perturbation
        Source / control perturbation label.
    cats
        List of categorical covariate keys.
    perturbation_key
        Which element of ``cats`` is the perturbation covariate.
    new_embeddings
        Optional dictionary mapping atomic perturbation names to embedding vectors.
    n_samples_from_source
        If set, randomly sample this many cells from the source group.
    seed
        Random seed for reproducibility.
    device
        Device to run prediction on.
    batch_size
        Batch size for forward passes. Default 256.
    source_adata
        Optional AnnData with control cells.

    Returns
    -------
    dict
        Dictionary mapping each perturbation to ``(x_ctrl, x_true, x_pred)``.
    """
    # Arguments that are identical for every target are bundled once.
    shared_kwargs = {
        "adata": adata,
        "source_perturbation": source_perturbation,
        "cats": cats,
        "perturbation_key": perturbation_key,
        "new_embeddings": new_embeddings,
        "n_samples_from_source": n_samples_from_source,
        "seed": seed,
        "device": device,
        "batch_size": batch_size,
        "source_adata": source_adata,
    }

    outcome = {}
    # A falsy ``perturbations`` simply leaves the loop with nothing to do.
    for target in (perturbations or ()):
        ctrl, truth, predicted = self.predict_perturbation(
            perturbation=target, **shared_kwargs
        )
        outcome[target] = (ctrl, truth, predicted)
    return outcome