Source code for celldisect.utils

from typing import Dict, List, Optional

import logging
import numpy as np
from torch import nn
import torch
from scvi.nn import one_hot

from enum import Enum

# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)


class TRAIN_MODE(int, Enum):
    """Integer identifiers for the training modes.

    Because of the ``int`` mixin, every member compares equal to its
    plain integer value (e.g. ``TRAIN_MODE.KL_Z == 2``).
    """

    RECONST = 0
    RECONST_CF = 1
    KL_Z = 2
    CLASSIFICATION = 3
    ADVERSARIAL = 4
class LOSS_KEYS(str, Enum):
    """String identifiers for the loss and metric entries.

    Because of the ``str`` mixin, every member compares equal to its raw
    string value (e.g. ``LOSS_KEYS.ACCURACY == "acc"``).
    """

    LOSS = "loss"
    RECONST_LOSS_X = "rec_x"
    RECONST_LOSS_X_CF = "rec_x_cf"
    KL_Z = "kl_z"
    CLASSIFICATION_LOSS = "ce"
    ACCURACY = "acc"
    F1 = "f1"
# Every LOSS_KEYS member, in definition order (the enum has no aliases, so
# this is exactly LOSS, RECONST_LOSS_X, RECONST_LOSS_X_CF, KL_Z,
# CLASSIFICATION_LOSS, ACCURACY, F1).
LOSS_KEYS_LIST = list(LOSS_KEYS)
def one_hot_cat(n_cat_list: List[int], cat_covs: torch.Tensor):
    """One-hot encode categorical covariates and concatenate them column-wise.

    Parameters
    ----------
    n_cat_list
        Number of categories of each covariate, in the column order of
        ``cat_covs``. Covariates with ``n_cat <= 1`` are skipped because they
        carry no information.
    cat_covs
        Tensor whose columns each hold one categorical covariate (it is split
        with ``torch.split(cat_covs, 1, dim=1)``), or ``None``.

    Returns
    -------
    The one-hot blocks of all covariates with ``n_cat > 1`` concatenated along
    dim 1, or ``None`` when no covariate has ``n_cat > 1``.

    Raises
    ------
    ValueError
        If ``n_cat_list`` describes more covariates than ``cat_covs`` provides.
    """
    cat_list = [] if cat_covs is None else list(torch.split(cat_covs, 1, dim=1))
    if len(n_cat_list) > len(cat_list):
        raise ValueError("nb. categorical args provided doesn't match init. params.")

    one_hot_cat_list = []
    for n_cat, cat in zip(n_cat_list, cat_list):
        if n_cat and cat is None:
            raise ValueError("cat not provided while n_cat != 0 in init. params.")
        if n_cat > 1:  # n_cat = 1 will be ignored - no additional information
            if cat.size(1) != n_cat:
                onehot_cat = one_hot(cat, n_cat)
            else:
                onehot_cat = cat  # cat has already been one_hot encoded
            one_hot_cat_list.append(onehot_cat)

    if not one_hot_cat_list:
        # Previously this indexed an empty list and raised IndexError.
        return None
    if len(one_hot_cat_list) == 1:
        return one_hot_cat_list[0]
    # torch.cat takes a sequence plus ``dim``; unpacking the list (the old
    # ``torch.cat(*lst)``) passed the second tensor as ``dim`` and crashed.
    return torch.cat(one_hot_cat_list, dim=1)
############################################################################## # Perturbation prediction utilities ##############################################################################
def parse_perturbation(name: str, delimiter: str = "+") -> List[str]:
    """Break a perturbation label into its atomic parts.

    Parameters
    ----------
    name
        Perturbation label such as ``"GeneA+GeneB"`` or ``"ctrl"``.
    delimiter
        Token separating the parts of a combinatorial perturbation.

    Returns
    -------
    The parts of ``name`` split on ``delimiter``, each with surrounding
    whitespace removed. A label without the delimiter yields a one-element list.
    """
    pieces = name.split(delimiter)
    return list(map(str.strip, pieces))
def validate_perturbation_embeddings(
    adata,
    perturbation_key: str,
    embedding_key: str,
    delimiter: str = "+",
) -> None:
    """Check that every atomic perturbation in *adata.obs* has an entry in *adata.uns*.

    Parameters
    ----------
    adata
        Annotated data object.
    perturbation_key
        Column in ``adata.obs`` containing perturbation labels.
    embedding_key
        Key in ``adata.uns`` whose value is a ``dict[str, array]`` mapping
        atomic perturbation names to their vector representations.
    delimiter
        Separator used for combinatorial perturbations.

    Raises
    ------
    KeyError
        If ``embedding_key`` is not found in ``adata.uns`` or
        ``perturbation_key`` is not a column of ``adata.obs``.
    TypeError
        If ``adata.uns[embedding_key]`` is not a ``dict``.
    ValueError
        If any atomic perturbation is missing from the embeddings dictionary,
        or the embeddings used do not all share the same shape.
    """
    if embedding_key not in adata.uns:
        raise KeyError(
            f"Perturbation embedding key '{embedding_key}' not found in adata.uns. "
            f"Available keys: {list(adata.uns.keys())}"
        )
    emb_dict = adata.uns[embedding_key]
    if not isinstance(emb_dict, dict):
        raise TypeError(
            f"adata.uns['{embedding_key}'] must be a dict mapping perturbation names "
            f"to embedding vectors, got {type(emb_dict)}."
        )
    if perturbation_key not in adata.obs:
        raise KeyError(
            f"Perturbation key '{perturbation_key}' not found in adata.obs."
        )

    unique_labels = adata.obs[perturbation_key].unique()
    # Labels are stringified first, so non-string labels are parsed by their str().
    all_atomic = {
        comp
        for label in unique_labels
        for comp in parse_perturbation(str(label), delimiter)
    }

    missing = all_atomic - set(emb_dict.keys())
    if missing:
        raise ValueError(
            f"The following atomic perturbations in adata.obs['{perturbation_key}'] "
            f"do not have embeddings in adata.uns['{embedding_key}']: {sorted(missing)}"
        )

    # Only the embeddings actually referenced by the labels must agree in shape.
    unique_shapes = {np.asarray(emb_dict[k]).shape for k in all_atomic}
    if len(unique_shapes) > 1:
        raise ValueError(
            f"All predefined perturbation embeddings must have the same dimensionality. "
            f"Found shapes: {unique_shapes}"
        )
    logger.info(
        "Validated %d atomic perturbations across %d unique labels.",
        len(all_atomic),
        len(unique_labels),
    )
def build_perturbation_embedding_matrix(
    category_names: List[str],
    predefined_embeddings: Dict[str, np.ndarray],
    delimiter: str = "+",
) -> torch.Tensor:
    """Build an embedding matrix for all perturbation categories.

    For combinatorial perturbations the component embeddings are summed.

    Parameters
    ----------
    category_names
        Ordered list of perturbation category names (from AnnData mapping).
        Each name is converted with ``str`` before being split.
    predefined_embeddings
        Dictionary mapping atomic perturbation names to vectors.
    delimiter
        Separator for combinatorial perturbations.

    Returns
    -------
    ``float32`` tensor of shape ``(n_categories, emb_dim)``; row ``i`` is the
    sum of the component embeddings of ``category_names[i]``. For an empty
    ``category_names`` the result has zero rows and the width of the first
    dictionary embedding (``(0, 0)`` if the dictionary is empty too).

    Raises
    ------
    KeyError
        If a component of some category has no entry in
        ``predefined_embeddings``.
    """
    names = list(category_names)
    if not names:
        # torch.stack([]) raises, so build the empty matrix explicitly.
        first = next(iter(predefined_embeddings.values()), None)
        trailing = (0,) if first is None else np.asarray(first).shape
        return torch.zeros((0, *trailing), dtype=torch.float32)

    rows = []
    for cat_name in names:
        components = parse_perturbation(str(cat_name), delimiter)
        absent = [comp for comp in components if comp not in predefined_embeddings]
        if absent:
            raise KeyError(
                f"No predefined embedding for {absent} "
                f"(components of perturbation category '{cat_name}')."
            )
        component_vecs = [
            torch.as_tensor(np.asarray(predefined_embeddings[comp], dtype=np.float32))
            for comp in components
        ]
        rows.append(torch.stack(component_vecs).sum(dim=0))
    return torch.stack(rows)
def perturbation_metrics(
    pred: np.ndarray,
    true: np.ndarray,
    ctrl: np.ndarray,
    top_n_de: int = 20,
) -> dict:
    """Compute standard perturbation prediction evaluation metrics.

    Every input is first transformed element-wise with ``np.log1p`` (so
    values are expected to be non-negative, not yet log-transformed). 2-D
    inputs are then averaged over axis 0 (cells) to obtain per-gene values.

    Parameters
    ----------
    pred
        Predicted gene expression, shape ``(n_genes,)`` or
        ``(n_cells, n_genes)``.
    true
        Ground-truth gene expression (same shape convention).
    ctrl
        Control (source) gene expression (same shape convention).
    top_n_de
        Number of top differentially expressed genes to evaluate.

    Returns
    -------
    Dictionary with the following keys, where ``N`` is ``top_n_de``:

    - ``pearson_mean`` -- Pearson *r* between predicted and true mean expression
    - ``pearson_delta`` -- Pearson *r* between predicted and true *delta* (vs ctrl)
    - ``mse`` -- Mean squared error of mean expression
    - ``top{N}_de_pearson`` -- Pearson *r* of predicted vs true mean expression
      on the top-N DE genes (ranked by ``|true - ctrl|`` in log1p space)
    - ``top{N}_de_cosine`` -- Cosine similarity of predicted vs true delta on
      the same top-N DE genes

    Raises
    ------
    ValueError
        If the per-gene vectors of ``pred``, ``true`` and ``ctrl`` differ in shape.
    """
    from scipy.stats import pearsonr

    def _log_mean(arr):
        # log1p in float64 (also handles integer counts), then average cells.
        arr = np.log1p(np.asarray(arr, dtype=np.float64))
        return arr.mean(axis=0) if arr.ndim == 2 else arr

    pred_mean = _log_mean(pred)
    true_mean = _log_mean(true)
    ctrl_mean = _log_mean(ctrl)
    if not (pred_mean.shape == true_mean.shape == ctrl_mean.shape):
        raise ValueError(
            "pred, true and ctrl must describe the same genes; got per-gene shapes "
            f"{pred_mean.shape}, {true_mean.shape} and {ctrl_mean.shape}."
        )

    delta_pred = pred_mean - ctrl_mean
    delta_true = true_mean - ctrl_mean

    r_mean, _ = pearsonr(pred_mean, true_mean)
    r_delta, _ = pearsonr(delta_pred, delta_true)
    mse = float(np.mean((pred_mean - true_mean) ** 2))

    # Rank genes by the magnitude of their true shift away from control.
    top_idx = np.argsort(-np.abs(delta_true))[:top_n_de]
    r_de, _ = pearsonr(pred_mean[top_idx], true_mean[top_idx])

    top_delta_pred = delta_pred[top_idx]
    top_delta_true = delta_true[top_idx]
    cos_num = np.dot(top_delta_pred, top_delta_true)
    # The epsilon keeps the ratio finite when either delta vector is all zeros.
    cos_den = np.linalg.norm(top_delta_pred) * np.linalg.norm(top_delta_true) + 1e-8
    cos_de = float(cos_num / cos_den)

    return {
        "pearson_mean": float(r_mean),
        "pearson_delta": float(r_delta),
        "mse": mse,
        f"top{top_n_de}_de_pearson": float(r_de),
        f"top{top_n_de}_de_cosine": cos_de,
    }
class PerturbationNetwork(nn.Module):
    """Map perturbation ids and dosages to a latent perturbation embedding.

    Adapted from the CPA code base. Each row of ``perts`` holds up to
    ``max_comb_len`` perturbation ids (combinatorial perturbations). Each id
    is embedded, multiplied by a learned function of its dosage, and the
    per-slot vectors are summed over the combination axis.

    NOTE(review): ``CPA_REGISTRY_KEYS``, ``FCLayers`` and ``GeneralizedSigmoid``
    are not imported in the visible part of this module — confirm they are
    defined or imported elsewhere, otherwise construction raises ``NameError``.

    Parameters
    ----------
    n_perts
        Number of perturbation ids (embedding-table size and number of MLP dosers).
    n_latent
        Output embedding dimension.
    doser_type
        ``'mlp'`` builds one small MLP per perturbation id; any other value is
        passed to ``GeneralizedSigmoid`` as ``non_linearity``.
    n_hidden, n_layers, dropout_rate
        Configuration of the per-perturbation MLP dosers (``'mlp'`` only).
    drug_embeddings
        Optional pre-built embedding module exposing ``embedding_dim``; its
        outputs are projected to ``n_latent`` by a linear layer.
    """

    def __init__(
        self,
        n_perts,
        n_latent,
        doser_type='logsigm',
        n_hidden=None,
        n_layers=None,
        dropout_rate: float = 0.0,
        drug_embeddings=None,
    ):
        super().__init__()
        self.n_latent = n_latent

        if drug_embeddings is not None:
            self.pert_embedding = drug_embeddings
            self.pert_transformation = nn.Linear(drug_embeddings.embedding_dim, n_latent)
            self.use_rdkit = True
        else:
            self.use_rdkit = False
            self.pert_embedding = nn.Embedding(
                n_perts, n_latent, padding_idx=CPA_REGISTRY_KEYS.PADDING_IDX
            )

        self.doser_type = doser_type
        if self.doser_type == 'mlp':
            # One scalar-in / scalar-out MLP per perturbation id.
            self.dosers = nn.ModuleList()
            for _ in range(n_perts):
                self.dosers.append(
                    FCLayers(
                        n_in=1,
                        n_out=1,
                        n_hidden=n_hidden,
                        n_layers=n_layers,
                        use_batch_norm=False,
                        use_layer_norm=True,
                        dropout_rate=dropout_rate,
                    )
                )
        else:
            self.dosers = GeneralizedSigmoid(n_perts, non_linearity=self.doser_type)

    def _scale_dosages(self, dosages, perts):
        """Return dose-transformed dosages with the same shape as ``dosages``."""
        if self.doser_type != 'mlp':
            return self.dosers(dosages, perts)
        # ``self.dosers`` is an nn.ModuleList here, which is not callable (the
        # original code crashed); apply each id's own MLP to that id's dosages.
        dtype = next(self.dosers.parameters()).dtype
        scaled = torch.zeros(dosages.shape, dtype=dtype, device=dosages.device)
        for pert_id in torch.unique(perts).tolist():
            slots = perts == pert_id
            doses = dosages[slots].unsqueeze(-1).to(dtype)  # (n_slots, 1)
            scaled[slots] = self.dosers[pert_id](doses).squeeze(-1).to(dtype)
        return scaled

    def forward(self, perts, dosages):
        """Embed a batch of (possibly combinatorial) perturbations.

        Parameters
        ----------
        perts
            ``(batch_size, max_comb_len)`` perturbation ids; slots with id ``0``
            are zeroed out (treated as padding — assumed to match
            ``CPA_REGISTRY_KEYS.PADDING_IDX``, TODO confirm).
        dosages
            ``(batch_size, max_comb_len)`` dosage of each slot.

        Returns
        -------
        ``(batch_size, n_latent)`` tensor.
        """
        perts = perts.long()
        scaled_dosages = self._scale_dosages(dosages, perts)  # (batch_size, max_comb_len)
        pert_embeddings = self.pert_embedding(perts)  # (batch_size, max_comb_len, emb_dim)
        if self.use_rdkit:
            # nn.Linear acts on the last dim, so no reshape is needed.
            pert_embeddings = self.pert_transformation(pert_embeddings)  # (batch_size, max_comb_len, n_latent)
        # Scale each slot's embedding by its dosage.
        z_perts = scaled_dosages.unsqueeze(-1) * pert_embeddings  # (batch_size, max_comb_len, n_latent)
        # Zero padding slots, then sum over the combination axis.
        mask = (perts != 0).unsqueeze(-1).to(z_perts.dtype)
        return (z_perts * mask).sum(dim=1)  # (batch_size, n_latent)