PyTorch Checkpoint Format

Our checkpoint format, optimized for large models, is based on the standard HDF5 file format. At a high level, when saving a checkpoint, the Cerebras stack takes a PyTorch state dictionary, flattens it, and stores it in an HDF5 file.

For example, the following state dictionary:

{
    "a": {
        "b": 0.1,
        "c": 0.001,
    },
    "d": [0.1, 0.2, 0.3]
}

Flattened H5:

{
    "a.b": 0.1,
    "a.c": 0.001,
    "d.0": 0.1,
    "d.1": 0.2,
    "d.2": 0.3,
}
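The flattening step can be pictured as a recursive walk that joins nested keys (and list indices) with dots. The helper below is a minimal illustrative sketch of this idea, not the actual Cerebras implementation:

```python
def flatten(obj, prefix=""):
    """Recursively flatten nested dicts/lists into dot-separated keys."""
    if isinstance(obj, dict):
        items = obj.items()
    elif isinstance(obj, (list, tuple)):
        items = ((str(i), v) for i, v in enumerate(obj))
    else:
        # Leaf value: store it under the accumulated key path.
        return {prefix: obj}
    flat = {}
    for key, value in items:
        full_key = f"{prefix}.{key}" if prefix else str(key)
        flat.update(flatten(value, full_key))
    return flat

state_dict = {"a": {"b": 0.1, "c": 0.001}, "d": [0.1, 0.2, 0.3]}
print(flatten(state_dict))
# → {'a.b': 0.1, 'a.c': 0.001, 'd.0': 0.1, 'd.1': 0.2, 'd.2': 0.3}
```

Each flattened key then maps naturally onto an entry in the HDF5 file.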

A model/optimizer state dictionary can be saved in the new checkpoint format using the cstorch.save method:

import cerebras.pytorch as cstorch

...

state_dict = {
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
}
cstorch.save(state_dict, "path/to/checkpoint")

...

A checkpoint saved this way can be loaded back using the cstorch.load method:

import cerebras.pytorch as cstorch

...

state_dict = cstorch.load("path/to/checkpoint")

model.load_state_dict(state_dict["model"])
optimizer.load_state_dict(state_dict["optimizer"])

...

Convert to Pickle Format

If cstorch.load does not meet your needs for loading the checkpoint into memory, the checkpoint can be converted to PyTorch's standard pickle format with a simple script:

import torch
import cerebras.pytorch as cstorch

state_dict = cstorch.load("path/to/checkpoint")
torch.save(state_dict, "path/to/new/checkpoint")
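After conversion, the file is a standard pickle-based PyTorch checkpoint, so it can be loaded with stock torch.load and no Cerebras dependencies. A minimal round-trip sketch with a plain state dictionary (the path here is hypothetical, created only for the demonstration):

```python
import os
import tempfile

import torch

# A stand-in for a converted checkpoint's contents.
state_dict = {"model": {"weight": torch.zeros(2, 2)}, "optimizer": {}}

# Save with stock torch.save, then reload with stock torch.load.
path = os.path.join(tempfile.mkdtemp(), "checkpoint.pt")
torch.save(state_dict, path)
loaded = torch.load(path, map_location="cpu")

assert torch.equal(loaded["model"]["weight"], state_dict["model"]["weight"])
```

From here, loaded["model"] and loaded["optimizer"] can be passed to load_state_dict as usual.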