celldisect.PerturbationEmbedding
- class celldisect.PerturbationEmbedding(*args: Any, **kwargs: Any)
Bases: Module
Embedding module backed by predefined external vectors (e.g. ESM, GenePT).
For combinatorial perturbations (e.g. "GeneA+GeneB"), the component embeddings are summed. The weight matrix is frozen by default.
- Parameters:
predefined_embeddings – Mapping from atomic perturbation name to its vector representation.
category_names – Ordered list of category names that mirrors the integer-index mapping produced by scvi-tools' CategoricalJointObsField.
combination_delimiter – String used to split combinatorial perturbation labels.
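The description above can be illustrated with a minimal NumPy sketch. It is not celldisect's implementation: `build_weight_matrix` is a hypothetical helper, `"+"` is assumed as the delimiter, row i is assumed to belong to `category_names[i]`, and a read-only NumPy array stands in for a frozen PyTorch parameter.

```python
import numpy as np


def build_weight_matrix(predefined_embeddings, category_names, combination_delimiter="+"):
    """Hypothetical helper: one row per category, in category_names order."""
    rows = []
    for name in category_names:
        # Split "GeneA+GeneB" into atomic components; an atomic name yields one part.
        parts = name.split(combination_delimiter)
        # A combinatorial perturbation is represented by the sum of its components.
        vectors = [np.asarray(predefined_embeddings[p], dtype=float) for p in parts]
        rows.append(np.sum(vectors, axis=0))
    weight = np.stack(rows)
    # Stand-in for a frozen (non-trainable) matrix: in-place writes now raise.
    weight.setflags(write=False)
    return weight


predefined = {"GeneA": [1.0, 0.0], "GeneB": [0.0, 2.0]}
W = build_weight_matrix(predefined, ["GeneA", "GeneB", "GeneA+GeneB"])
print(W[2])  # [1. 2.]
```

Because rows follow `category_names`, the integer codes produced by the categorical field can index the matrix directly.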
Methods
__init__(predefined_embeddings, category_names)
add_perturbation(name[, embedding]) – Register a new perturbation (possibly unseen during training).
forward(indices)
rebuild_for_mapping(new_category_names) – Rebuild the weight matrix for a different category-to-index mapping.
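The body of `forward(indices)` is not documented here. Assuming it behaves like `torch.nn.Embedding.forward`, i.e. it returns one weight row per integer index, the lookup reduces to fancy indexing. The sketch below uses NumPy in place of PyTorch; the toy matrix and the `forward` function are hypothetical.

```python
import numpy as np

# Toy weight matrix; row i is assumed to belong to category_names[i].
category_names = ["GeneA", "GeneB", "GeneA+GeneB"]
weight = np.array([[1.0, 0.0], [0.0, 2.0], [1.0, 2.0]])


def forward(indices):
    """Hypothetical stand-in for forward(): gather one row per integer index."""
    return weight[np.asarray(indices, dtype=np.int64)]


# A batch of integer labels, as a categorical obs field would encode them.
batch = forward([2, 0, 1, 2])
print(batch.shape)  # (4, 2)
```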
Attributes
weight – Alias so callers expecting nn.Embedding attributes still work.
- add_perturbation(name: str, embedding: numpy.ndarray | None = None) → int
Register a new perturbation (possibly unseen during training).
- Parameters:
name – Perturbation label. Can be combinatorial (e.g. "A+B").
embedding – Vector for an atomic perturbation that is not yet in the predefined dictionary. Ignored for names already known.
- Returns:
Integer index of the (new or existing) perturbation.
- Return type:
int
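The registration rules above (reuse the existing index for a known label, record a vector only for an atomic name not yet predefined, sum components for a combinatorial label) can be sketched with a toy class. `ToyPerturbationTable` is hypothetical and simplified: it keeps rows in a Python list rather than a tensor and assumes `"+"` as the delimiter.

```python
import numpy as np


class ToyPerturbationTable:
    """Hypothetical, simplified stand-in used only to illustrate add_perturbation."""

    def __init__(self, predefined_embeddings, category_names, delimiter="+"):
        self.predefined = {k: np.asarray(v, dtype=float) for k, v in predefined_embeddings.items()}
        self.delimiter = delimiter
        self.names = []
        self.rows = []
        for name in category_names:
            self.add_perturbation(name)

    def _vector(self, name):
        # Sum the atomic components; raises KeyError if a component has no vector.
        return np.sum([self.predefined[p] for p in name.split(self.delimiter)], axis=0)

    def add_perturbation(self, name, embedding=None):
        # A label that is already registered keeps its index; embedding is ignored.
        if name in self.names:
            return self.names.index(name)
        # Record a new atomic vector only when the name is not predefined yet.
        if embedding is not None and name not in self.predefined:
            self.predefined[name] = np.asarray(embedding, dtype=float)
        self.rows.append(self._vector(name))
        self.names.append(name)
        return len(self.names) - 1


table = ToyPerturbationTable({"GeneA": [1.0, 0.0], "GeneB": [0.0, 2.0]}, ["GeneA", "GeneB"])
idx_new = table.add_perturbation("GeneC", np.array([3.0, 3.0]))  # unseen atomic name
idx_combo = table.add_perturbation("GeneA+GeneC")  # combinatorial, built from known parts
idx_again = table.add_perturbation("GeneA")  # already registered
print(idx_new, idx_combo, idx_again)  # 2 3 0
```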
- rebuild_for_mapping(new_category_names: list) → None
Rebuild the weight matrix for a different category-to-index mapping.
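Only the one-line summary of `rebuild_for_mapping` is given here. One plausible reading, sketched below as an assumption, is that rows are re-laid out so that row i matches `new_category_names[i]`, for example when another dataset's categorical field assigns different integer codes to the same labels. `remap_rows` is a hypothetical helper; a label absent from the old mapping raises `ValueError` and would need to be registered first.

```python
import numpy as np


def remap_rows(weight, old_names, new_names):
    """Hypothetical: reorder rows so row i of the result belongs to new_names[i]."""
    # Position of each new label in the old mapping (ValueError if unseen).
    order = [old_names.index(name) for name in new_names]
    return weight[order]


old_names = ["GeneA", "GeneB", "GeneA+GeneB"]
weight = np.array([[1.0, 0.0], [0.0, 2.0], [1.0, 2.0]])
new_names = ["GeneA+GeneB", "GeneA"]
remapped = remap_rows(weight, old_names, new_names)
print(remapped.tolist())  # [[1.0, 2.0], [1.0, 0.0]]
```

Realigning by name rather than by position keeps each label paired with its own vector regardless of how the integer codes change.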
- property weight
Alias so callers expecting nn.Embedding attributes still work.
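The alias pattern can be shown without PyTorch. In this hypothetical sketch, `FrozenLookup` exposes a read-only matrix under the name `weight`, so a generic helper written against the `nn.Embedding` convention (reading `module.weight.shape[1]` as the embedding dimension) works unchanged.

```python
import numpy as np


class FrozenLookup:
    """Hypothetical class exposing a read-only `weight` alias over an internal matrix."""

    def __init__(self, matrix):
        self._matrix = np.array(matrix, dtype=float)  # private copy
        self._matrix.setflags(write=False)  # stand-in for a frozen parameter

    @property
    def weight(self):
        # Same attribute name as nn.Embedding.weight, so generic code can read it.
        return self._matrix


def embedding_dim(module):
    # Generic caller written against the nn.Embedding convention.
    return module.weight.shape[1]


print(embedding_dim(FrozenLookup([[1.0, 0.0, 0.5], [0.0, 2.0, 0.5]])))  # 3
```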