This page covers the basics of the `Trainer` class. By the end, you should have a cursory understanding of how to use the `Trainer` class.
The `Trainer` class can be imported from `cerebras.modelzoo`. Its constructor takes in the following arguments:
- `device`: The device to run training/validation on.
- `model_dir`: The directory in which to store model-related artifacts (e.g. model checkpoints).
- `model`: The `torch.nn.Module` instance that we are training/validating.
- `optimizer`: Optionally, a `cerebras.pytorch.optim.Optimizer` instance can be passed in to optimize the model weights during the training phase.
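Putting these arguments together, constructing a `Trainer` might look like the following sketch. The model and optimizer here are placeholders, and the `device` value is an assumption; consult the Cerebras API reference for exact signatures.

```python
import torch
import cerebras.pytorch as cstorch
from cerebras.modelzoo import Trainer

# Placeholder model; in practice, this is the torch.nn.Module you want to train.
model = torch.nn.Linear(784, 10)

# Optimizer from the cerebras.pytorch.optim namespace.
optimizer = cstorch.optim.SGD(model.parameters(), lr=0.01)

trainer = Trainer(
    device="CSX",             # assumed device name for a Cerebras system
    model_dir="./model_dir",  # where checkpoints and other artifacts are stored
    model=model,
    optimizer=optimizer,
)
```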
The `fit` method takes in the following arguments:
- `train_dataloader`: The `cerebras.pytorch.utils.data.DataLoader` instance to use during training.
- `val_dataloader`: Optionally, a `cerebras.pytorch.utils.data.DataLoader` instance can be passed in to run validation during and/or at the end of training.
There are many ways to configure the `Trainer` to fit your needs.
If only a `train_dataloader` and a `val_dataloader` are provided to the `fit` call, the default behaviour is to run a single epoch of training followed by a single epoch of validation. This behaviour can be configured by passing a `TrainingLoop` instance to the `Trainer`.
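For example, using the numbers referenced later in this section (1000 training steps, validation every 100 steps), the loop might be configured as in the sketch below. The `TrainingLoop` import path, the `loop` keyword, and the `eval_steps` value are assumptions to be checked against the API reference.

```python
from cerebras.modelzoo import Trainer
from cerebras.modelzoo.trainer.callbacks import TrainingLoop  # assumed import path

trainer = Trainer(
    ...,  # device, model_dir, model, optimizer as before
    loop=TrainingLoop(
        num_steps=1000,      # total number of training batches
        eval_steps=50,       # batches per validation pass (illustrative value)
        eval_frequency=100,  # run validation every 100 training steps
    ),
)
trainer.fit(train_dataloader, val_dataloader)
```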
`num_steps` represents the total number of batches to train on. If `num_steps` exceeds the number of batches available in the train dataloader, the dataloader is automatically repeated so that training can run for the full `num_steps`.
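To make the repetition behaviour concrete, here is a small pure-Python sketch (not Cerebras code) of how a short dataloader can be cycled to supply exactly `num_steps` batches:

```python
from itertools import cycle, islice

def take_training_batches(dataloader, num_steps):
    """Yield exactly num_steps batches, repeating the dataloader if it is too short."""
    return list(islice(cycle(dataloader), num_steps))

# A toy "dataloader" with only 4 batches, trained for 10 steps:
batches = take_training_batches(["b0", "b1", "b2", "b3"], num_steps=10)
# The 4 batches are cycled through until 10 batches have been consumed.
```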
`eval_steps` represents the number of steps to run each time validation is performed. As with training, if `eval_steps` exceeds the number of batches available in the val dataloader, the dataloader is automatically repeated. However, validation is typically not run for more than a single epoch, so it is advised to set `eval_steps` to at most the length of the validation dataloader; otherwise, the validation metrics may be incorrect.
`eval_frequency` represents how often validation is run during training. In the above example, validation is run every 100 training steps; that is, throughout the 1000 steps of training, validation is run 10 times. Regardless of its exact value, as long as `eval_frequency` is greater than zero, validation is always run at the end of training.
The `Trainer` can be further configured to save checkpoints at regular intervals by passing in a `Checkpoint` instance.
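For example, saving a checkpoint every 100 steps might look like the sketch below. The `Checkpoint` import path, the `checkpoint` keyword, and the `steps` parameter name are assumptions to be checked against the API reference.

```python
from cerebras.modelzoo import Trainer
from cerebras.modelzoo.trainer.callbacks import Checkpoint  # assumed import path

trainer = Trainer(
    ...,  # device, model_dir, model, optimizer, loop as before
    checkpoint=Checkpoint(steps=100),  # save a checkpoint every 100 training steps
)
```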
A checkpoint is also saved at the end of training whenever `num_steps` is a multiple of the checkpoint steps.
The checkpoints are saved in the `model_dir` directory that was passed to the `Trainer`.
To resume training from a saved checkpoint, pass a `ckpt_path` argument to the call to `fit`.
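For example (the checkpoint filename below is hypothetical):

```python
trainer.fit(
    train_dataloader,
    val_dataloader,
    ckpt_path="./model_dir/checkpoint_100.mdl",  # hypothetical checkpoint file
)
```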
If `ckpt_path` is not provided but a checkpoint is found inside the `model_dir`, the `Trainer` will automatically load the latest checkpoint found in the `model_dir`. To learn more about checkpointing with the `Trainer`, see Checkpointing.
To learn more about the `Trainer`, you can check out the following topics:

- How to configure a `Trainer` instance using a YAML configuration file.
- How to use the `Trainer` in some core workflows.
- How to extend the capabilities of the `Trainer` class.
- What the `Trainer` class outputs during a run.