celldisect.CellDISECT.train
- CellDISECT.train(max_epochs: int | None = None, use_gpu: str | int | bool | None = True, train_size: float = 0.8, validation_size: float | None = None, batch_size: int = 256, early_stopping: bool = True, save_best: bool = False, plan_kwargs: dict | None = None, recon_weight: Tunable[Union[float, int]] = 10, cf_weight: Tunable[Union[float, int]] = 1, beta: Tunable[Union[float, int]] = 1, clf_weight: Tunable[Union[float, int]] = 50, adv_clf_weight: Tunable[Union[float, int]] = 10, adv_period: Tunable[int] = 1, n_cf: Tunable[int] = 10, kappa_optimizer2: bool = True, n_epochs_pretrain_ae: int = 0, **trainer_kwargs)
Train the model.
- Parameters:
max_epochs (Optional[int]) – Number of passes through the dataset. If None, defaults to np.min([round((20000 / n_cells) * 400), 400]).
use_gpu (Optional[Union[str, int, bool]]) – Whether to use GPU for training. Can be a boolean, string, or integer specifying the GPU device.
train_size (float) – Size of the training set in the range [0.0, 1.0].
validation_size (Optional[float]) – Size of the validation set. If None, defaults to 1 - train_size. If train_size + validation_size < 1, the remaining cells belong to a test set.
batch_size (int) – Minibatch size to use during training.
early_stopping (bool) – Perform early stopping. Additional arguments can be passed in **trainer_kwargs. See Trainer for further options.
save_best (bool) – If True, save the best model state with respect to the validation loss; if False (the default in the signature above), use the final state in the training procedure.
plan_kwargs (Optional[dict]) – Keyword arguments for TrainingPlan. Keyword arguments passed to train() will overwrite values present in plan_kwargs, when appropriate.
recon_weight (Tunable[Union[float, int]]) – Weight for the reconstruction loss of X.
cf_weight (Tunable[Union[float, int]]) – Weight for the reconstruction loss of X_cf.
beta (Tunable[Union[float, int]]) – Weight for the KL divergence of Zi.
clf_weight (Tunable[Union[float, int]]) – Weight for the Si classifier loss.
adv_clf_weight (Tunable[Union[float, int]]) – Weight for the adversarial classifier loss.
adv_period (Tunable[int]) – Adversarial training period.
n_cf (Tunable[int]) – Number of X_cf reconstructions (a random permutation of n VAEs and a random half-batch subset for each trial).
kappa_optimizer2 (bool) – Whether to use the second kappa optimizer.
n_epochs_pretrain_ae (int) – Number of epochs to pretrain the autoencoder.
**trainer_kwargs – Other keyword arguments for Trainer.
- Return type:
None
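The max_epochs default and the train/validation/test split described in the parameters above can be worked through with a small sketch. The helper names default_max_epochs and split_sizes are hypothetical and not part of CellDISECT. The rounding of the split sizes (ceiling for the training set, floor for the validation set) is an assumption made for illustration and may differ from what the library does internally.

```python
import math


def default_max_epochs(n_cells: int) -> int:
    """Documented heuristic: min(round((20000 / n_cells) * 400), 400)."""
    if n_cells <= 0:
        raise ValueError("n_cells must be positive")
    # Up to 20000 cells the cap of 400 epochs applies; larger datasets get fewer.
    return min(round((20000 / n_cells) * 400), 400)


def split_sizes(n_cells: int, train_size: float = 0.8, validation_size=None):
    """Return (n_train, n_val, n_test) for the given fractions.

    Assumption: training count is rounded up and validation count is
    rounded down; the library's own rounding is not stated in the docs.
    """
    if not 0.0 < train_size <= 1.0:
        raise ValueError("train_size must be in (0.0, 1.0]")
    n_train = math.ceil(train_size * n_cells)
    if validation_size is None:
        # Documented default: validation_size = 1 - train_size,
        # so every remaining cell goes to validation.
        n_val = n_cells - n_train
    else:
        if validation_size < 0.0 or train_size + validation_size > 1.0 + 1e-9:
            raise ValueError("train_size + validation_size must not exceed 1")
        n_val = math.floor(validation_size * n_cells)
    # Whatever is left over forms the test set (documented behaviour).
    n_test = n_cells - n_train - n_val
    return n_train, n_val, n_test
```

With a model instance (called model here purely for illustration), a call such as model.train(train_size=0.8, validation_size=0.1) would leave max_epochs to the heuristic above and, per the validation_size description, hold out the remaining cells as a test set.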