Writing a Custom Training Loop
Learn how to write a custom training loop for a simple, fully connected model on the MNIST dataset.
Our typical workflow involves using a training script provided in the Cerebras Model Zoo. However, if that training loop is insufficient for your model needs, you may write your own training loop using the Cerebras PyTorch API.
The following steps cover only the minimum code required to run a simple, small model on the Cerebras Wafer-Scale Cluster. To extend the script with features such as learning rate scheduling or gradient scaling, see the further reading section.
Prerequisites
You have installed the cerebras.pytorch package in your environment.
Validate the Package Installation
Before continuing, validate that the cerebras.pytorch package is installed correctly, for example by importing it in your Python environment.
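One way to check for the package from Python is sketched below. It uses only the standard library; the helper name is_installed is hypothetical and is not part of the Cerebras API.

```python
import importlib.util


def is_installed(module_name: str) -> bool:
    """Return True if `module_name` can be located on the import path."""
    try:
        # find_spec imports parent packages but not the target module itself.
        return importlib.util.find_spec(module_name) is not None
    except ModuleNotFoundError:
        # Raised when a parent package (for example `cerebras`) is missing.
        return False


print("cerebras.pytorch installed:", is_installed("cerebras.pytorch"))
```

If this reports False, the import of cerebras.pytorch in the later steps would fail as well.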
From here on, we will use cstorch as the alias for cerebras.pytorch.
Configure the Wafer-Scale Cluster
To configure the Cerebras Wafer-Scale Cluster, construct a ClusterConfig object, then use it to construct a backend object.
See the class documentation for ClusterConfig to view all configurable options. Most options have reasonable defaults and do not need to be changed.
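A minimal sketch of this step follows. The helper build_backend is hypothetical. The module path cstorch.distributed.ClusterConfig, the "CSX" backend name, and the cluster_config keyword are assumptions that are not stated in this text, so check them against the API reference before relying on them.

```python
def build_backend(**cluster_options):
    """Construct a ClusterConfig, then use it to construct a backend."""
    # Imported lazily so this definition loads even without the package.
    import cerebras.pytorch as cstorch

    # Assumed location of ClusterConfig; with no options, defaults apply.
    cluster_config = cstorch.distributed.ClusterConfig(**cluster_options)
    # "CSX" is assumed to name the Wafer-Scale Cluster backend type.
    return cstorch.backend("CSX", cluster_config=cluster_config)
```

Calling build_backend() with no arguments relies entirely on the defaults mentioned above.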
Define Your Model
When using the Cerebras PyTorch API, define your model in the same way you would in a vanilla PyTorch workflow.
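As an illustration, a small fully connected MNIST classifier in plain PyTorch might look like the sketch below. The layer sizes and the class name Model are illustrative choices, not values taken from this guide.

```python
import torch
from torch import nn


class Model(nn.Module):
    """A small fully connected classifier for 28x28 MNIST images."""

    def __init__(self, hidden_size: int = 256, num_classes: int = 10):
        super().__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28 * 28, hidden_size)
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        hidden = torch.relu(self.fc1(self.flatten(inputs)))
        return self.fc2(hidden)  # raw logits, one per digit class


model = Model()
```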
Eagerly initializing the weights of very large models can be slow and can cause out-of-memory errors. See the page on Efficient weight initialization for ways to work around this issue.
Compile Your Model
Once the model has been instantiated, compile it by calling cerebras.pytorch.compile.
The call to cstorch.compile does not actually compile the model. Similar to torch.compile, it only prepares the model for compilation. Compilation only happens after the first iteration, once the input shapes are known.
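The call could be sketched as follows. The helper prepare_for_cluster is hypothetical, and passing the model and backend as the two positional arguments is an assumption about the signature.

```python
def prepare_for_cluster(model, backend):
    """Wrap `model` so it can later be compiled for the given backend."""
    # Imported lazily so this definition loads even without the package.
    import cerebras.pytorch as cstorch

    # Nothing is compiled here; that is deferred until the first
    # iteration, when the input shapes are known.
    return cstorch.compile(model, backend)
```

The returned object is then called in place of the original model inside the training step.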
Optimize Model Parameters
To optimize model parameters using the Cerebras Wafer-Scale Cluster, you must use a Cerebras-compliant optimizer. Exact drop-in replacements for all commonly used optimizers are available in cerebras.pytorch.optim.
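For instance, a stochastic gradient descent optimizer might be created as in the sketch below. The helper build_optimizer and the learning rate are illustrative, and the availability of SGD under cerebras.pytorch.optim with the same arguments as torch.optim.SGD is assumed from the drop-in claim above.

```python
def build_optimizer(model, lr=0.01):
    """Create a Cerebras-compliant SGD optimizer for `model`."""
    # Imported lazily so this definition loads even without the package.
    import cerebras.pytorch as cstorch

    # Assumed to accept the same arguments as torch.optim.SGD.
    return cstorch.optim.SGD(model.parameters(), lr=lr)
```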
If you are interested in writing your own custom Cerebras-compliant optimizer, see the page on Writing Custom Optimizers.
Dataloaders
To send data to the Wafer-Scale Cluster, wrap your PyTorch dataloader with cerebras.pytorch.utils.data.DataLoader.
The Cerebras PyTorch dataloader requires a callable that generates a PyTorch dataloader. This approach ensures that each worker can independently create its own dataloader instance, optimizing distributed parallelism.
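The callable described above could be sketched as follows. To keep the example runnable without downloading anything, it uses random tensors shaped like MNIST images as a stand-in for the real dataset; the function name and its parameters are illustrative.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def get_torch_dataloader(batch_size: int, num_samples: int = 256) -> DataLoader:
    """Build a fresh PyTorch dataloader; each worker would call this itself."""
    # Synthetic stand-in for MNIST: random 1x28x28 images, labels in [0, 10).
    images = torch.rand(num_samples, 1, 28, 28)
    labels = torch.randint(0, 10, (num_samples,))
    # drop_last keeps every batch the same shape.
    return DataLoader(
        TensorDataset(images, labels),
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
    )


# On the cluster, the callable itself (not its result) is what gets wrapped:
#   cstorch.utils.data.DataLoader(get_torch_dataloader, batch_size=64)
# Forwarding of the extra keyword argument to the callable is an assumption.
```

Passing the function rather than a constructed dataloader is what allows each worker to build its own instance.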
Define the Training Step
To execute a single training iteration on the Wafer-Scale Cluster, you first need to capture all operations intended to run on the cluster. Do this by defining a function that includes all actions for a single training iteration and decorating it with cerebras.pytorch.trace.
This function is traced and sent to the cluster for compilation and execution.
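A sketch of such a function is shown below. The factory make_training_step is a hypothetical helper that takes the decorator as an argument so the same step can also run locally; applying cerebras.pytorch.trace as a plain decorator is an assumption about how it is used.

```python
def make_training_step(model, loss_fn, optimizer, trace=None):
    """Return a function that performs one full training iteration."""

    def training_step(inputs, targets):
        outputs = model(inputs)            # forward pass
        loss = loss_fn(outputs, targets)   # compute the loss
        loss.backward()                    # backward pass
        optimizer.step()                   # update the parameters
        optimizer.zero_grad()              # reset gradients for the next step
        return loss

    # On the cluster, `trace` would be cerebras.pytorch.trace (assumed usage);
    # with trace=None the undecorated step is returned for local runs.
    return trace(training_step) if trace is not None else training_step
```

Everything inside training_step is what would be captured, so host-side work such as printing belongs outside it.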
Define an Execution
To program an execution run on the Cerebras Wafer-Scale Cluster, create an instance of cerebras.pytorch.utils.data.DataExecutor. The executor takes the Cerebras PyTorch dataloader to use during the run, the total number of steps to run, and the interval at which checkpoints are taken.
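Constructing the executor might look like the sketch below. The helper build_executor and the step counts are illustrative, and the keyword names num_steps and checkpoint_steps are assumptions that should be checked against the DataExecutor documentation.

```python
def build_executor(dataloader, num_steps=100, checkpoint_steps=50):
    """Create an executor for a run of `num_steps` training steps."""
    # Imported lazily so this definition loads even without the package.
    import cerebras.pytorch as cstorch

    return cstorch.utils.data.DataExecutor(
        dataloader,                         # the Cerebras PyTorch dataloader
        num_steps=num_steps,                # total number of steps to run
        checkpoint_steps=checkpoint_steps,  # checkpoint interval, in steps
    )
```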
Train Your Model
Once the above is defined, you can iterate through the executor to train your model.
- The loss must be passed into a function decorated by step_closure. This is required to retrieve the loss value from the cluster before it can be used. See the page on step closures for more details.
- Checkpoints must be saved inside a function decorated by checkpoint_closure. This is required to retrieve the model weights and optimizer state from the cluster before they can be saved. See the page on saving checkpoints for more details.
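The loop and the two closures described above could be sketched as follows. The helper train is hypothetical. The use of cstorch.save, the executor attribute user_iteration, and the checkpoint file name are assumptions not stated in this text.

```python
def train(executor, training_step, model, optimizer):
    """Iterate through the executor, logging the loss and saving checkpoints."""
    # Imported lazily so this definition loads even without the package.
    import cerebras.pytorch as cstorch

    @cstorch.step_closure
    def print_loss(loss, step):
        # Runs once the loss value has been retrieved from the cluster.
        print(f"Train loss at step {step}: {loss.item()}")

    @cstorch.checkpoint_closure
    def save_checkpoint(step):
        # Runs once weights and optimizer state are back from the cluster.
        state = {"model": model.state_dict(), "optimizer": optimizer.state_dict()}
        cstorch.save(state, f"checkpoint_{step}.mdl")  # assumed save helper

    for inputs, targets in executor:
        loss = training_step(inputs, targets)
        # user_iteration is assumed to hold the current step number.
        print_loss(loss, executor.user_iteration)
        save_checkpoint(executor.user_iteration)
```

Both closures are called on every iteration; the checkpoint closure is expected to act only at the checkpoint interval configured on the executor.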
Putting It All Together
Combining all of the above steps, in order, gives a minimal training script for a simple, fully connected model trained on the MNIST dataset.