Writing a Custom Training Loop
Our typical workflow involves using a training script provided in the Cerebras Model Zoo. However, if that training loop is insufficient for your model's needs, you can write your own training loop using the Cerebras PyTorch API.
The following steps show how to write a custom training loop for a simple, fully connected model trained on the MNIST dataset.
Note that these steps cover only the minimum code required to run a simple, small model on the Cerebras Wafer-Scale Cluster. To extend the script with features such as learning rate scheduling and gradient scaling, continue to the further reading section.
Prerequisites
You have installed the cerebras.pytorch package in your environment.
Validate the Package Installation
To check whether the cerebras.pytorch package is installed correctly, try importing it from Python.
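The following Python sketch performs that check. It only assumes the package is importable as cerebras.pytorch; reading a __version__ attribute is an assumption, so the sketch falls back to "unknown" if that attribute is absent.

```python
import importlib


def installed_cstorch_version():
    """Return the cerebras.pytorch version string, or None if it is not installed."""
    try:
        module = importlib.import_module("cerebras.pytorch")
    except ImportError:
        # Raised when either "cerebras" or "cerebras.pytorch" cannot be found.
        return None
    # __version__ is assumed, not guaranteed; fall back to a placeholder.
    return str(getattr(module, "__version__", "unknown"))


version = installed_cstorch_version()
if version is None:
    print("cerebras.pytorch is not installed in this environment")
else:
    print(f"cerebras.pytorch {version} is installed")
```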
From here on, we use cstorch as the alias for cerebras.pytorch.
Define Your Model
When using the Cerebras PyTorch API, you can define your model in the same way you would in a vanilla PyTorch workflow.
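For example, a small fully connected MNIST classifier can be written with plain torch.nn modules. The layer sizes below (784 inputs, one hidden layer of 256 units, 10 classes) and the class name MNISTModel are illustrative choices, not requirements of the API.

```python
import torch


class MNISTModel(torch.nn.Module):
    """A two-layer fully connected classifier for 28x28 grayscale images."""

    def __init__(self, hidden_size: int = 256, num_classes: int = 10):
        super().__init__()
        self.flatten = torch.nn.Flatten()  # (N, 1, 28, 28) -> (N, 784)
        self.fc1 = torch.nn.Linear(28 * 28, hidden_size)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(hidden_size, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Returns unnormalized logits, suitable for torch.nn.CrossEntropyLoss.
        return self.fc2(self.relu(self.fc1(self.flatten(x))))


model = MNISTModel()
```

Nothing here is Cerebras-specific, so the same model can be unit-tested on CPU before it is compiled.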
Weight initialization for large models can cause out-of-memory errors, and initializing extremely large models eagerly can also be very slow. See the page on Efficient weight initialization for how to work around these issues.
Compile Your Model
Once the model has been instantiated, compile it by calling cerebras.pytorch.compile.
You must pass in the backend you wish to compile the model with. To use all default arguments, simply pass in the backend type; to customize the backend, instantiate a backend object using cerebras.pytorch.backend.
The call to cstorch.compile doesn't actually compile the model. Similar to torch.compile, it only prepares the model for compilation. Compilation happens only when the first iteration runs, once the input shapes are known.
Optimize Model Parameters
To optimize model parameters on the Cerebras Wafer-Scale Cluster, you must use a Cerebras-compliant optimizer. Exact drop-in replacements for all commonly used optimizers are available in cerebras.pytorch.optim.
If you are interested in writing your own Cerebras-compliant optimizer, see the page on Writing custom optimizers.
DataLoaders
To send data to the Wafer-Scale Cluster, you must wrap your PyTorch dataloader with cerebras.pytorch.utils.data.DataLoader.
The Cerebras PyTorch dataloader takes a callable that returns a PyTorch dataloader. It is designed this way so that every worker can create its own PyTorch dataloader instance, maximizing distributed parallelism.
Define the Training Step
To run a single training iteration on the Cerebras Wafer-Scale Cluster, you must first capture everything that is intended to run on the cluster. To do this, define a function containing everything that happens in a single training iteration, and decorate it with cerebras.pytorch.trace.
This function gets traced and sent to the cluster for compilation and execution.
Define an Execution
To set up an execution run on the Cerebras Wafer-Scale Cluster, you must define an instance of cerebras.pytorch.utils.data.DataExecutor.
It takes the Cerebras PyTorch dataloader to use during the run, the total number of steps to run for, and the interval at which checkpoints are taken.
Configuring the Cerebras Wafer-Scale Cluster
To configure the Cerebras Wafer-Scale Cluster, construct a CSConfig object, which can then be passed to the cerebras.pytorch.utils.data.DataExecutor object. See the class documentation for CSConfig for all configurable options.
Most options have reasonable defaults and do not need to be changed.
Train Your Model
Once the above is defined, you can iterate through the executor to train your model.
- Notice how the loss is passed into a function decorated by step_closure. This is required to retrieve the loss value from the Cerebras Wafer-Scale Cluster before it can be used. See the page on step closures for more details.
- Also notice how checkpoints are saved inside a function decorated by checkpoint_closure. This is required to retrieve the model weights and optimizer state from the Cerebras Wafer-Scale Cluster before they can be saved. See the page on saving checkpoints.
Putting it All Together
Combining all of the above steps, we can create a minimal training script for a simple, fully connected model trained on the MNIST dataset.