Utilities

Utility Functions

class celldisect.utils.TRAIN_MODE(value)[source]

Bases: int, Enum

Integer-valued enumeration of training modes. Each member compares equal to its integer value.

RECONST = 0
RECONST_CF = 1
KL_Z = 2
CLASSIFICATION = 3
ADVERSARIAL = 4
__format__(format_spec)

Returns format using actual value type unless __str__ has been overridden.

class celldisect.utils.LOSS_KEYS(value)[source]

Bases: str, Enum

String-valued enumeration of loss and metric keys. Each member compares equal to its string value.

LOSS = 'loss'
RECONST_LOSS_X = 'rec_x'
RECONST_LOSS_X_CF = 'rec_x_cf'
KL_Z = 'kl_z'
CLASSIFICATION_LOSS = 'ce'
ACCURACY = 'acc'
F1 = 'f1'
__format__(format_spec)

Returns format using actual value type unless __str__ has been overridden.
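
Because both enumerations mix in a built-in type (int and str respectively), their members can be compared with, or looked up from, plain values. The sketch below re-declares the documented members locally so that it runs without celldisect installed; with the package available, the two classes would be imported from celldisect.utils instead.

```python
from enum import Enum


# Local mirrors of the documented members (stand-ins for celldisect.utils).
class TRAIN_MODE(int, Enum):
    RECONST = 0
    RECONST_CF = 1
    KL_Z = 2
    CLASSIFICATION = 3
    ADVERSARIAL = 4


class LOSS_KEYS(str, Enum):
    LOSS = "loss"
    RECONST_LOSS_X = "rec_x"
    RECONST_LOSS_X_CF = "rec_x_cf"
    KL_Z = "kl_z"
    CLASSIFICATION_LOSS = "ce"
    ACCURACY = "acc"
    F1 = "f1"


# Look a member up from its raw value.
mode = TRAIN_MODE(3)            # TRAIN_MODE.CLASSIFICATION
key = LOSS_KEYS("acc")          # LOSS_KEYS.ACCURACY

# Members compare equal to their underlying int / str values.
kl_is_two = TRAIN_MODE.KL_Z == 2
loss_is_str = LOSS_KEYS.LOSS == "loss"
```

Using .value is the most portable way to obtain the raw value, since the output of format() on mixed-in enum members differs between Python versions.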

celldisect.utils.one_hot_cat(n_cat_list: List[int], cat_covs: torch.Tensor)[source]

celldisect.utils.parse_perturbation(name: str, delimiter: str = '+') → List[str][source]

Split a (possibly combinatorial) perturbation name into atomic components.

Parameters:
  • name – Perturbation label, e.g. "GeneA+GeneB" or "ctrl".

  • delimiter – Separator used for combinatorial perturbations.

Returns:

List of atomic perturbation names.
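
The documented behaviour can be illustrated with a small stand-in. This is a minimal sketch, not the library's implementation: parse_perturbation_sketch is a hypothetical name, and it assumes plain splitting on the delimiter with no whitespace stripping or deduplication.

```python
from typing import List


def parse_perturbation_sketch(name: str, delimiter: str = "+") -> List[str]:
    """Stand-in for parse_perturbation: split a label on the delimiter."""
    # Assumption: plain splitting, no whitespace stripping or deduplication.
    return name.split(delimiter)


combo = parse_perturbation_sketch("GeneA+GeneB")   # ['GeneA', 'GeneB']
single = parse_perturbation_sketch("ctrl")         # ['ctrl']
```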

celldisect.utils.validate_perturbation_embeddings(adata, perturbation_key: str, embedding_key: str, delimiter: str = '+') → None[source]

Check that every atomic perturbation in adata.obs has an entry in adata.uns.

Parameters:
  • adata – Annotated data object.

  • perturbation_key – Column in adata.obs containing perturbation labels.

  • embedding_key – Key in adata.uns whose value is a dict[str, array] mapping atomic perturbation names to their vector representations.

  • delimiter – Separator used for combinatorial perturbations.

Raises:
  • KeyError – If embedding_key is not found in adata.uns.

  • ValueError – If any atomic perturbation is missing from the embeddings dictionary.
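
The checks described above can be sketched as follows. This is an illustrative stand-in under stated assumptions, not the library's code: validate_sketch is a hypothetical name, the SimpleNamespace object only imitates the .obs and .uns attributes of an AnnData object, and the key names "perturbation" and "pert_emb" are made up for the example.

```python
from types import SimpleNamespace


def validate_sketch(adata, perturbation_key, embedding_key, delimiter="+"):
    """Stand-in mirroring the documented checks; returns None on success."""
    if embedding_key not in adata.uns:
        raise KeyError(f"{embedding_key!r} not found in adata.uns")
    embeddings = adata.uns[embedding_key]
    # Collect every atomic perturbation that appears in the label column.
    atoms = set()
    for label in adata.obs[perturbation_key]:
        atoms.update(str(label).split(delimiter))
    missing = sorted(atoms - set(embeddings))
    if missing:
        raise ValueError(f"No embedding for perturbations: {missing}")


# Hypothetical stand-in for an AnnData object: only .obs and .uns are used.
adata = SimpleNamespace(
    obs={"perturbation": ["ctrl", "GeneA", "GeneA+GeneB"]},
    uns={"pert_emb": {"ctrl": [0.0, 0.0], "GeneA": [1.0, 0.0]}},
)

try:
    validate_sketch(adata, "perturbation", "pert_emb")
    error_message = None
except ValueError as err:
    # GeneB occurs in a combinatorial label but has no embedding.
    error_message = str(err)
```

Running the validation once before model setup surfaces a missing embedding as a single readable error rather than a lookup failure later on.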

celldisect.utils.build_perturbation_embedding_matrix(category_names: List[str], predefined_embeddings: Dict[str, numpy.ndarray], delimiter: str = '+') → torch.Tensor[source]

Build an embedding matrix for all perturbation categories.

For combinatorial perturbations the component embeddings are summed.

Parameters:
  • category_names – Ordered list of perturbation category names (from AnnData mapping).

  • predefined_embeddings – Dictionary mapping atomic perturbation names to vectors.

  • delimiter – Separator for combinatorial perturbations.

Returns:

Tensor of shape (n_categories, emb_dim).
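
The summation rule for combinatorial perturbations can be sketched with NumPy alone. This is a stand-in under stated assumptions rather than the library's implementation: build_matrix_sketch is a hypothetical name, it returns a NumPy array instead of a torch.Tensor to keep the example dependency-light, and the embedding values are invented for illustration.

```python
from typing import Dict, List

import numpy as np


def build_matrix_sketch(category_names: List[str],
                        predefined_embeddings: Dict[str, np.ndarray],
                        delimiter: str = "+") -> np.ndarray:
    """Stand-in: one row per category, summing component embeddings."""
    rows = []
    for name in category_names:
        parts = name.split(delimiter)
        # A single perturbation has one part, so the sum is its own vector.
        rows.append(np.sum([predefined_embeddings[p] for p in parts], axis=0))
    return np.stack(rows)


# Invented 2-dimensional embeddings, for illustration only.
emb = {
    "ctrl": np.array([0.0, 0.0]),
    "GeneA": np.array([1.0, 0.0]),
    "GeneB": np.array([0.0, 2.0]),
}
matrix = build_matrix_sketch(["ctrl", "GeneA", "GeneA+GeneB"], emb)
# matrix has shape (3, 2); the last row is GeneA + GeneB = [1.0, 2.0]
```

Row order follows category_names, so row i corresponds to the i-th category. If a tensor is needed, torch.as_tensor(matrix) converts the result.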

celldisect.utils.perturbation_metrics(pred: numpy.ndarray, true: numpy.ndarray, ctrl: numpy.ndarray, top_n_de: int = 20) → dict[source]

Compute standard perturbation prediction evaluation metrics.

Parameters:
  • pred – Predicted mean gene expression, shape (n_genes,) or (n_cells, n_genes). If 2-D the mean across cells is taken.

  • true – Ground-truth mean gene expression (same shape convention).

  • ctrl – Control (source) mean gene expression (same shape convention).

  • top_n_de – Number of top differentially expressed genes to evaluate.

Returns:

  • Dictionary with the following keys:

    • pearson_mean – Pearson r between predicted and true mean expression

    • pearson_delta – Pearson r between predicted and true delta (vs ctrl)

    • mse – Mean squared error of mean expression

    • top_de_pearson – Pearson r on top-N DE genes (ranked by |true - ctrl|)

    • top_de_cosine – Cosine similarity on top-N DE genes
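
The listed metrics can be sketched with NumPy as below. This is a minimal stand-in, not the library's implementation: perturbation_metrics_sketch is a hypothetical name, the input values are invented, and the sketch assumes that the two top-DE metrics compare deltas (expression minus control) rather than raw means, which the description above does not specify.

```python
import numpy as np


def perturbation_metrics_sketch(pred, true, ctrl, top_n_de=20):
    """Stand-in computing the documented keys with NumPy only."""
    # Reduce 2-D (n_cells, n_genes) inputs to per-gene means.
    pred, true, ctrl = (
        np.asarray(a, dtype=float).mean(axis=0) if np.ndim(a) == 2
        else np.asarray(a, dtype=float)
        for a in (pred, true, ctrl)
    )
    pred_delta, true_delta = pred - ctrl, true - ctrl

    def pearson(x, y):
        return float(np.corrcoef(x, y)[0, 1])

    # Top-N DE genes ranked by |true - ctrl|, largest first.
    top = np.argsort(-np.abs(true_delta))[:top_n_de]
    # Assumption: the top-DE metrics compare deltas rather than raw means.
    a, b = pred_delta[top], true_delta[top]
    cosine = float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))
    return {
        "pearson_mean": pearson(pred, true),
        "pearson_delta": pearson(pred_delta, true_delta),
        "mse": float(np.mean((pred - true) ** 2)),
        "top_de_pearson": pearson(a, b),
        "top_de_cosine": cosine,
    }


# Invented values: four genes, with a perfect prediction for illustration.
ctrl = np.array([1.0, 1.0, 1.0, 1.0])
true = np.array([1.0, 3.0, 0.5, 1.2])
pred = true.copy()
scores = perturbation_metrics_sketch(pred, true, ctrl, top_n_de=2)
```

With a perfect prediction the correlations and the cosine similarity are 1 (up to floating-point error) and the MSE is 0. Pearson r is undefined when either vector is constant, so real inputs should contain some variation across genes.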

class celldisect.utils.PerturbationNetwork(*args: Any, **kwargs: Any)[source]

Bases: Module

forward(perts, dosages)[source]

Parameters:
  • perts – Shape (batch_size, max_comb_len).

  • dosages – Shape (batch_size, max_comb_len).