PyTorch Checkpoint Format

Our checkpoint format, optimized for large models, is based on the standard HDF5 file format. At a high level, when saving a checkpoint, the Cerebras stack takes a PyTorch state dictionary, flattens it, and stores it in an HDF5 file. For example, the following state dict:

{
    "a": {
        "b": 0.1,
        "c": 0.001,
    },
    "d": [0.1, 0.2, 0.3]
}

would be flattened and stored in the HDF5 file as follows:

{
    "a.b": 0.1,
    "a.c": 0.001,
    "d.0": 0.1,
    "d.1": 0.2,
    "d.2": 0.3,
}
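
To illustrate the idea, a minimal sketch of such a flattening routine is shown below. The helper flatten_state_dict is hypothetical and shown for illustration only; it is not the actual implementation used by the Cerebras stack.

def flatten_state_dict(obj, prefix=""):
    """Recursively flatten nested dicts and lists into dot-separated keys."""
    if isinstance(obj, dict):
        items = obj.items()
    elif isinstance(obj, (list, tuple)):
        items = ((str(i), v) for i, v in enumerate(obj))
    else:
        # Leaf value: store it under the accumulated key.
        return {prefix: obj}
    flat = {}
    for key, value in items:
        new_prefix = f"{prefix}.{key}" if prefix else str(key)
        flat.update(flatten_state_dict(value, new_prefix))
    return flat

# Reproduces the example above.
assert flatten_state_dict({"a": {"b": 0.1, "c": 0.001}, "d": [0.1, 0.2, 0.3]}) == {
    "a.b": 0.1,
    "a.c": 0.001,
    "d.0": 0.1,
    "d.1": 0.2,
    "d.2": 0.3,
}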

A model or optimizer state dict can be saved in this checkpoint format using the cbtorch.save method. For example:

import cerebras.pytorch as cbtorch

...

state_dict = {
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
}
cbtorch.save(state_dict, "path/to/checkpoint")

...
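
Because the file produced is standard HDF5, its contents can also be inspected with ordinary HDF5 tooling such as h5py. The snippet below is a sketch: it simply prints the names of all objects in the file, and the exact internal layout of those objects is an implementation detail of the Cerebras format.

import h5py

# Print every group/dataset name in the checkpoint file. The precise
# layout of the stored objects is an internal detail and may differ
# between versions of the format.
with h5py.File("path/to/checkpoint", "r") as f:
    f.visit(print)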

A checkpoint saved in this way can be loaded back using the cbtorch.load method. For example:

import cerebras.pytorch as cbtorch

...

state_dict = cbtorch.load("path/to/checkpoint")

model.load_state_dict(state_dict["model"])
optimizer.load_state_dict(state_dict["optimizer"])

...

If you’re using the run.py scripts provided in ModelZoo, the configuration and setup mentioned earlier are already handled automatically by the built-in runners.

Converting Checkpoint Formats

If loading the checkpoint with cbtorch.load does not fit your use case, it can be converted to the standard pickle-based format that PyTorch uses, as follows:

import torch
import cerebras.pytorch as cbtorch

# Load the HDF5-based checkpoint into memory, then re-save it with
# torch.save, producing a standard pickle-based PyTorch checkpoint.
state_dict = cbtorch.load("path/to/checkpoint")
torch.save(state_dict, "path/to/new/checkpoint")
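
After the conversion, the checkpoint no longer depends on the Cerebras package and can be loaded with vanilla PyTorch:

import torch

# Load the converted, pickle-based checkpoint with standard PyTorch.
state_dict = torch.load("path/to/new/checkpoint")

model.load_state_dict(state_dict["model"])
optimizer.load_state_dict(state_dict["optimizer"])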