celldisect.CellDISECT.train

CellDISECT.train(max_epochs: int | None = None, use_gpu: str | int | bool | None = True, train_size: float = 0.8, validation_size: float | None = None, batch_size: int = 256, early_stopping: bool = True, save_best: bool = False, plan_kwargs: dict | None = None, recon_weight: Tunable[Union[float, int]] = 10, cf_weight: Tunable[Union[float, int]] = 1, beta: Tunable[Union[float, int]] = 1, clf_weight: Tunable[Union[float, int]] = 50, adv_clf_weight: Tunable[Union[float, int]] = 10, adv_period: Tunable[int] = 1, n_cf: Tunable[int] = 10, kappa_optimizer2: bool = True, n_epochs_pretrain_ae: int = 0, **trainer_kwargs)

Train the model.

Parameters:
  • max_epochs (Optional[int]) – Number of passes through the dataset. If None, defaults to np.min([round((20000 / n_cells) * 400), 400]).

  • use_gpu (Optional[Union[str, int, bool]]) – Whether to use a GPU for training. A boolean turns GPU use on or off; a string or integer selects a specific GPU device.

  • train_size (float) – Fraction of cells used for training, in the range [0.0, 1.0].

  • validation_size (Optional[float]) – Fraction of cells used for validation. If None, defaults to 1 - train_size. If train_size + validation_size < 1, the remaining cells form a test set.

  • batch_size (int) – Minibatch size to use during training.

  • early_stopping (bool) – Whether to perform early stopping. Additional early-stopping arguments can be passed via **trainer_kwargs; see Trainer for further options.

  • save_best (bool) – If True, keep the best model state with respect to the validation loss; if False (default), use the final state of the training procedure.

  • plan_kwargs (Optional[dict]) – Keyword arguments for TrainingPlan. Keyword arguments passed to train() will overwrite values present in plan_kwargs, when appropriate.

  • recon_weight (Tunable[Union[float, int]]) – Weight for the reconstruction loss of X.

  • cf_weight (Tunable[Union[float, int]]) – Weight for the reconstruction loss of X_cf.

  • beta (Tunable[Union[float, int]]) – Weight for the KL divergence of Zi.

  • clf_weight (Tunable[Union[float, int]]) – Weight for the Si classifier loss.

  • adv_clf_weight (Tunable[Union[float, int]]) – Weight for the adversarial classifier loss.

  • adv_period (Tunable[int]) – Adversarial training period.

  • n_cf (Tunable[int]) – Number of X_cf reconstructions. Each trial uses a random permutation of the n VAEs and a random subset of half the minibatch.

  • kappa_optimizer2 (bool) – Whether to use the second kappa optimizer.

  • n_epochs_pretrain_ae (int) – Number of epochs to pretrain the autoencoder.

  • **trainer_kwargs – Other keyword arguments for Trainer.

Return type:

None
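To see how the max_epochs default scales with dataset size, the documented formula can be evaluated directly. The sketch below is a plain-Python rewrite of that expression (built-in min in place of np.min); default_max_epochs is a hypothetical helper name, not part of the CellDISECT API.

```python
def default_max_epochs(n_cells: int) -> int:
    """Hypothetical helper: epochs used when max_epochs is None.

    Mirrors the documented default np.min([round((20000 / n_cells) * 400), 400])
    in plain Python: smaller datasets get more passes, capped at 400.
    """
    return min(round((20000 / n_cells) * 400), 400)


for n_cells in (10_000, 50_000, 1_000_000):
    print(n_cells, default_max_epochs(n_cells))
# 10000 400
# 50000 160
# 1000000 8
```

With 20,000 cells or fewer the 400-epoch cap applies; larger datasets get proportionally fewer passes.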
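The interplay of train_size and validation_size comes down to a little arithmetic. This sketch assumes the training count is rounded up and the validation count rounded down, a rule modeled on scvi-tools' data splitter; CellDISECT may round differently, and split_sizes is a hypothetical helper.

```python
import math
from typing import Optional, Tuple


def split_sizes(
    n_cells: int, train_size: float = 0.8, validation_size: Optional[float] = None
) -> Tuple[int, int, int]:
    """Hypothetical helper: return (n_train, n_val, n_test) cell counts.

    Rounding (ceil for training, floor for validation) is an assumption.
    """
    if not 0.0 <= train_size <= 1.0:
        raise ValueError("train_size must be in [0.0, 1.0]")
    n_train = math.ceil(train_size * n_cells)
    if validation_size is None:
        # Documented default: validation gets 1 - train_size, so no test set.
        n_val = n_cells - n_train
    else:
        # Small tolerance guards against float noise such as 0.6 + 0.4.
        if train_size + validation_size > 1.0 + 1e-9:
            raise ValueError("train_size + validation_size must not exceed 1")
        n_val = math.floor(validation_size * n_cells)
    # Whatever is left over forms the test set.
    n_test = n_cells - n_train - n_val
    return n_train, n_val, n_test


print(split_sizes(1000))            # (800, 200, 0)
print(split_sizes(1000, 0.8, 0.1))  # (800, 100, 100)
```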
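The plan_kwargs entry says that arguments passed to train() overwrite values in plan_kwargs. A minimal sketch of that precedence rule, assuming the loss weights given to train() are forwarded into the training plan alongside plan_kwargs (the exact routing inside CellDISECT is an assumption, and "when appropriate" may mean only some keys are overridden). merge_plan_kwargs is hypothetical, and "lr" is only an illustrative key.

```python
from typing import Any, Dict, Optional


def merge_plan_kwargs(plan_kwargs: Optional[Dict[str, Any]], **train_args: Any) -> Dict[str, Any]:
    """Hypothetical helper: combine plan_kwargs with arguments given to train().

    Values passed directly to train() take precedence over the same keys in
    plan_kwargs. The caller's dict is copied, not mutated.
    """
    merged = dict(plan_kwargs) if plan_kwargs is not None else {}
    merged.update(train_args)  # later (train()) values win on key clashes
    return merged


user_plan = {"lr": 1e-3, "recon_weight": 5}
merged = merge_plan_kwargs(user_plan, recon_weight=10, cf_weight=1)
print(merged)     # {'lr': 0.001, 'recon_weight': 10, 'cf_weight': 1}
print(user_plan)  # {'lr': 0.001, 'recon_weight': 5}
```

Copying before updating keeps a plan_kwargs dict reusable across several train() calls.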
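The n_cf entry describes two random draws per trial: a permutation of the n VAEs and a half-batch subset of cells. The sketch below illustrates only that sampling; how CellDISECT then uses each permutation and subset to build X_cf is not described here. sample_cf_trials and its parameters n_vaes and seed are hypothetical.

```python
import random
from typing import List, Tuple


def sample_cf_trials(
    n_cf: int, n_vaes: int, batch_size: int, seed: int = 0
) -> List[Tuple[List[int], List[int]]]:
    """Hypothetical sketch of the random choices described for n_cf.

    Returns one (vae_order, cell_indices) pair per trial: a random
    permutation of the n_vaes VAE indices and a random subset containing
    half of the minibatch's cell indices.
    """
    rng = random.Random(seed)
    trials = []
    for _ in range(n_cf):
        # Random permutation of the VAE indices 0..n_vaes-1.
        vae_order = rng.sample(range(n_vaes), n_vaes)
        # Random half-batch subset of cell indices, sorted for readability.
        cell_indices = sorted(rng.sample(range(batch_size), batch_size // 2))
        trials.append((vae_order, cell_indices))
    return trials


trials = sample_cf_trials(n_cf=10, n_vaes=3, batch_size=256)
print(len(trials), trials[0][0], len(trials[0][1]))
```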