Saving and Loading Checkpoints
To save and load weights in a Cerebras run, we provide a custom Cerebras H5-based checkpoint format that is far more performant and efficient than the core PyTorch pickle-based checkpoint format, especially for models with extremely large weights, such as Large Language Models (LLMs). To save a checkpoint, we provide a cerebras.pytorch.save function that you can use in exactly the same way as torch.save:
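For example, here is a minimal sketch of saving a checkpoint, assuming a model and optimizer have already been constructed (the cstorch alias, the state dict keys, and the checkpoint file name are illustrative choices, not requirements):

```python
import cerebras.pytorch as cstorch

# cerebras.pytorch.save accepts the same arguments as torch.save:
# an object to serialize and a path to write it to.
cstorch.save(
    {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    },
    "checkpoint.mdl",
)
```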
Similarly, we provide a cerebras.pytorch.load function that can also be used in exactly the same way as torch.load:
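And a corresponding sketch of loading that checkpoint back, again assuming the same illustrative state dict keys and file name as above:

```python
import cerebras.pytorch as cstorch

# cerebras.pytorch.load accepts the same arguments as torch.load
# and returns the deserialized object (here, a state dict).
state_dict = cstorch.load("checkpoint.mdl")

model.load_state_dict(state_dict["model"])
optimizer.load_state_dict(state_dict["optimizer"])
```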
Checkpoint Closures
It is only possible to fetch weights on predetermined checkpoint steps configured using the DataExecutor. This restriction exists to keep training performant. For example, if the configuration was checkpoint_steps=100, you are only allowed to fetch the weights to take a checkpoint on every 100th step and at the very end, on the last step.
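As a rough sketch, the checkpoint schedule might be configured when constructing the DataExecutor; the dataloader variable and the exact keyword arguments shown here are assumptions and may differ across releases:

```python
import cerebras.pytorch as cstorch

# `dataloader` is assumed to be a Cerebras-compatible dataloader
# (e.g. a cstorch.utils.data.DataLoader wrapping a PyTorch dataloader).
executor = cstorch.utils.data.DataExecutor(
    dataloader,
    num_steps=1000,        # total number of steps to run
    checkpoint_steps=100,  # weights may only be fetched every 100th step (and on the last step)
)
```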
To aid this, you can use the checkpoint_closure decorator, which is a step closure that checks that the current step is a checkpoint step before calling the function. In addition, using this decorator ensures that the weights are available to fetch from the server before they are saved to the checkpoint file.
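Putting the pieces together, a checkpoint-saving function might be wrapped in the decorator (here assumed to live at cstorch.checkpoint_closure) and called on every step; on non-checkpoint steps the call is simply skipped. The model, optimizer, and executor names and the .mdl file name below are assumptions carried over from the earlier sketches:

```python
import cerebras.pytorch as cstorch

@cstorch.checkpoint_closure
def save_checkpoint(step):
    # Only executes on checkpoint steps, by which point the weights
    # are guaranteed to have been fetched from the server.
    cstorch.save(
        {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "step": step,
        },
        f"checkpoint_{step}.mdl",
    )

for step, batch in enumerate(executor):
    ...  # forward/backward/optimizer step
    save_checkpoint(step)  # skipped except on checkpoint steps
```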
Converting Checkpoints to a Pickle-Based Format
If you have a checkpoint in the Cerebras H5-based format and wish to use it in a CPU/GPU workflow, it can easily be converted to a PyTorch pickle-based format:
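The conversion itself can be as simple as loading the H5-based checkpoint and re-saving it with torch.save; the file paths below are placeholders:

```python
import torch
import cerebras.pytorch as cstorch

# Load the Cerebras H5-based checkpoint, then write it back out
# in the standard PyTorch pickle-based format.
state_dict = cstorch.load("path/to/checkpoint.mdl")
torch.save(state_dict, "path/to/checkpoint.pt")
```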
Note that this conversion eagerly loads the entire checkpoint into memory, so it may cause memory issues for very large models.
In most cases, converting the checkpoint is unnecessary, as the lazily loaded checkpoint acquired via the cerebras.pytorch.load function will also work in CPU/GPU workflows.