celldisect.PerturbationEmbedding

class celldisect.PerturbationEmbedding(*args: Any, **kwargs: Any)

Bases: Module

Embedding module backed by predefined external vectors (e.g. ESM, GenePT).

For combinatorial perturbations (e.g. "GeneA+GeneB"), component embeddings are summed. The weight matrix is frozen by default.

Parameters:
  • predefined_embeddings – Mapping from atomic perturbation name to its vector representation.

  • category_names – Ordered list of category names that mirrors the integer-index mapping produced by scvi-tools’ CategoricalJointObsField.

  • combination_delimiter – String used to split combinatorial perturbation labels.
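The way the three parameters fit together can be sketched in plain Python. This is only an illustration under stated assumptions, not the celldisect implementation: the vectors are made up, the helper `embed_label` is hypothetical, and the delimiter `"+"` is assumed from the "GeneA+GeneB" example above. Plain lists stand in for the real vectors.

```python
# Hypothetical stand-ins; not the celldisect implementation.
predefined_embeddings = {
    "GeneA": [1.0, 0.0, 2.0],
    "GeneB": [0.5, 1.0, 0.0],
}
category_names = ["GeneA", "GeneB", "GeneA+GeneB"]
combination_delimiter = "+"  # assumed value, taken from the "GeneA+GeneB" example


def embed_label(label, embeddings, delimiter):
    """Sum the vectors of every atomic perturbation named in `label`."""
    components = label.split(delimiter)
    vectors = [embeddings[name] for name in components]
    # Element-wise sum across the component vectors.
    return [sum(values) for values in zip(*vectors)]


# Row i belongs to category_names[i], so an integer category index
# can be used directly as a row lookup.
weight_rows = [
    embed_label(label, predefined_embeddings, combination_delimiter)
    for label in category_names
]
```

Here the row for "GeneA+GeneB" is the element-wise sum of the "GeneA" and "GeneB" rows, and the row order follows category_names.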

Methods

__init__(predefined_embeddings, category_names)

add_perturbation(name[, embedding])

Register a new perturbation (possibly unseen during training).

forward(indices)

rebuild_for_mapping(new_category_names)

Rebuild the weight matrix for a different category-to-index mapping.

Attributes

weight

Alias so callers expecting nn.Embedding attributes still work.

add_perturbation(name: str, embedding: numpy.ndarray | None = None) → int

Register a new perturbation (possibly unseen during training).

Parameters:
  • name – Perturbation label. Can be combinatorial ("A+B").

  • embedding – Vector for an atomic perturbation that is not yet in the predefined dictionary. Ignored for names already known.

Returns:

Integer index of the (new or existing) perturbation.
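The documented contract (return the index of a known name, accept a vector only for an unknown atomic name, allow combinatorial labels) could look roughly like the bookkeeping below. `PerturbationRegistry` is a hypothetical stand-in written for this illustration; the `"+"` default and the `KeyError` for a missing component are assumptions, and the real class may behave differently.

```python
class PerturbationRegistry:
    """Hypothetical bookkeeping only; not the celldisect class."""

    def __init__(self, embeddings, names, delimiter="+"):
        self.embeddings = dict(embeddings)  # atomic name -> vector
        self.names = list(names)            # position in list == integer index
        self.delimiter = delimiter          # assumed default

    def add_perturbation(self, name, embedding=None):
        components = name.split(self.delimiter)
        # A supplied vector is used only for an atomic name not yet known;
        # for names already known it is ignored.
        if (
            len(components) == 1
            and name not in self.embeddings
            and embedding is not None
        ):
            self.embeddings[name] = list(embedding)
        missing = [c for c in components if c not in self.embeddings]
        if missing:
            # Assumed behaviour: the real class may handle this case differently.
            raise KeyError(f"no embedding available for {missing}")
        if name not in self.names:
            self.names.append(name)
        return self.names.index(name)


registry = PerturbationRegistry({"A": [1.0, 0.0]}, ["A"])
idx_existing = registry.add_perturbation("A", embedding=[9.0, 9.0])  # vector ignored
idx_new = registry.add_perturbation("B", embedding=[0.0, 1.0])
idx_combo = registry.add_perturbation("A+B")
```

Calling the method again with a label that is already registered returns the same index rather than creating a second entry.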

rebuild_for_mapping(new_category_names: list) → None

Rebuild the weight matrix for a different category-to-index mapping.
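One way to picture a rebuild is to recompute every row from the atomic vectors in the order given by the new category list. The sketch below assumes exactly that; `build_rows`, the `embeddings` dictionary, and the `"+"` delimiter are hypothetical, and the actual method may reuse or reorder existing rows instead.

```python
embeddings = {"A": [1.0, 0.0], "B": [0.0, 1.0]}  # hypothetical atomic vectors


def build_rows(category_names, delimiter="+"):
    """Return one summed vector per label, in the order of `category_names`."""
    rows = []
    for label in category_names:
        parts = [embeddings[part] for part in label.split(delimiter)]
        rows.append([sum(values) for values in zip(*parts)])
    return rows


old_rows = build_rows(["A", "B", "A+B"])  # original index mapping
new_rows = build_rows(["A+B", "A"])       # different mapping: index 0 is now "A+B"
```

After the rebuild, index 0 resolves to the combined vector, which is why indices from the old mapping must not be reused against the new rows.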

property weight

Alias so callers expecting nn.Embedding attributes still work.