Restartable Dataloaders
In the Cerebras PyTorch API 2.0, we provide revamped support for deterministically restarting any custom input-generating dataloaders used for a run. This feature enables the saving and loading of the dataloader state and seamlessly integrates with our existing mechanism of capturing checkpoints for a run.
Saving DataLoader State
Similar to how you call state_dict on components such as the model and optimizer to fetch and save state information in our Cerebras H5-based checkpoint format, you can save the state of your dataloader by calling state_dict on the Cerebras PyTorch dataloader wrapper, which must be initialized at the beginning of a custom training loop.
- Our typical workflow in Model Zoo already includes this call on the Cerebras PyTorch dataloader wrapper to save the state of the dataloader being used for the run.
- The dataloader state can only be saved at a checkpoint step, i.e. you should wrap the function that saves the dataloader state in the checkpoint_closure decorator.
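For illustration, here is a minimal sketch of saving the dataloader state inside a checkpoint closure. It assumes the Cerebras PyTorch package is imported as cstorch, that input_fn, model, and optimizer are defined elsewhere, and the checkpoint filename is illustrative:

```python
import cerebras.pytorch as cstorch  # adjust the import to match your release

# Wrap the input-generating function with the Cerebras PyTorch dataloader
# wrapper at the beginning of the custom training loop.
dataloader = cstorch.utils.data.DataLoader(input_fn)


@cstorch.checkpoint_closure
def save_checkpoint(step):
    # state_dict() on the wrapper captures the dataloader state so it is
    # saved alongside the model and optimizer states in the H5 checkpoint.
    cstorch.save(
        {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "dataloader": dataloader.state_dict(),
        },
        f"checkpoint_{step}.mdl",
    )
```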
Loading DataLoader State
Upon restarting a run from a Cerebras checkpoint file, you can fetch the saved dataloader state (if it exists) from the loaded checkpoint and pass it to the load_state_dict
method on the Cerebras PyTorch dataloader wrapper to load your dataloader’s state, e.g.
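A sketch of restoring the dataloader state on restart, under the same assumptions as above (the checkpoint path is illustrative):

```python
import cerebras.pytorch as cstorch  # adjust the import to match your release

dataloader = cstorch.utils.data.DataLoader(input_fn)

# Load the Cerebras H5 checkpoint and restore the dataloader state, if present.
state_dict = cstorch.load("checkpoint_100.mdl")
if "dataloader" in state_dict:
    dataloader.load_state_dict(state_dict["dataloader"])
```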
And that is all!
Now, to specify what “state” information of your dataloader should be saved in a checkpoint when state_dict is called on the Cerebras PyTorch dataloader wrapper, and how this state information should be loaded to rewind your dataloader, your dataloader must conform to the protocol class RestartableDataLoader.
Restartable DataLoader API
By implementing the methods state_dict, aggregate_state_dict, deaggregate_state_dict, and load_state_dict with the appropriate method signatures, your dataloader is guaranteed to be restartable. That is, you are able to save the state of your dataloader in a checkpoint and load it via the mechanism described above.
Recall that in a distributed setting, each input worker per CSX creates its own instance of the dataloader for parallelism. Thus, how you implement these four methods determines how your dataloader’s state is saved and loaded to enable deterministic restarts in such settings.
To illustrate the usage of this protocol with an example, we define our CustomRestartableDataLoader class below. The following subsections describe each method signature more generally and within the context of our custom class.
state_dict
Use this method to specify what state information each input-generating worker should capture at an appliance checkpoint step. By default, each worker captures some internal state info using our new Cerebras dataloader checkpoint format defined by the DataLoaderCheckpoint
dataclass. Please refer to the linked docs on this class for detailed information on each attribute. Essentially, in your definition of state_dict you may choose to save any of the aforementioned internal state info per worker. We expose an API method get_worker_state
that you may utilize in your implementation of state_dict to fetch the worker’s internal state info, e.g.
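A sketch of such an implementation is below. The DataLoaderCheckpoint attribute names used here (worker_step and global_worker_id) are illustrative, and get_worker_state is assumed to be exposed under cstorch.distributed; adjust both to match your release:

```python
from typing import Any, Dict

import torch
import cerebras.pytorch as cstorch  # adjust the import to match your release


class CustomRestartableDataLoader(torch.utils.data.DataLoader):
    def state_dict(self) -> Dict[str, Any]:
        # Only valid inside state_dict: fetch this worker's internal
        # DataLoaderCheckpoint state.
        worker_state = cstorch.distributed.get_worker_state()
        # Save only the fields this example cares about (attribute names are
        # illustrative; see the DataLoaderCheckpoint docs for the full list).
        return {
            "worker_step": worker_state.worker_step,
            "worker_id": worker_state.global_worker_id,
        }

    # aggregate_state_dict, deaggregate_state_dict, and load_state_dict are
    # shown in the following subsections.
```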
- The call to get_worker_state is well-defined only inside your implementation of state_dict; calling this method anywhere else will result in a RuntimeError exception.
- Ensure that any other state info you choose to save is picklable using the dill package.
aggregate_state_dict
This method accepts the list of individual worker state dicts as an argument. Each state dict in this list holds per-worker state information as defined in your implementation of the state_dict method.
Use this method to specify how to combine the state information of all workers in a single, consolidated state dict, e.g.
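A sketch continuing the CustomRestartableDataLoader class, assuming the per-worker state dicts produced by the state_dict sketch above (the aggregated key names are illustrative):

```python
from typing import Any, Dict, List

import torch


class CustomRestartableDataLoader(torch.utils.data.DataLoader):
    # ... state_dict as defined in the previous subsection ...

    def aggregate_state_dict(
        self, worker_states: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        # This example assumes a total of two workers for the run;
        # worker_states is ordered by global worker id.
        assert len(worker_states) == 2
        return {
            "worker_0_step": worker_states[0]["worker_step"],
            "worker_1_id": worker_states[1]["worker_id"],
            "combined_step": (
                worker_states[0]["worker_step"]
                + worker_states[1]["worker_step"]
            ),
        }
```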
- The aggregated state dict represents the state of your dataloader and will eventually be saved in our Cerebras H5 checkpoint file when state_dict is invoked on the Cerebras PyTorch dataloader wrapper to save your dataloader’s state.
- In the example above, we assume a total of two workers for the run. In the aggregated state dict, we choose to save worker 0’s step, worker 1’s global worker id, and the summed step count of both workers as the state of our dataloader.
- You can expect the worker_states list to be ordered by the global worker id of each worker.
deaggregate_state_dict
This method accepts an aggregated state dict as an argument. The aggregated state dict represents the state of your dataloader, as specified by your dataloader’s aggregate_state_dict method.
To load your dataloader’s state, use this method to specify how the consolidated dataloader state loaded from a checkpoint should be disaggregated into a single state dict that defines how each worker should load its state, e.g.
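A sketch of the disaggregation step, continuing the same class (the combined_step key follows the aggregation sketch above):

```python
from typing import Any, Dict

import torch


class CustomRestartableDataLoader(torch.utils.data.DataLoader):
    # ... state_dict and aggregate_state_dict as defined above ...

    def deaggregate_state_dict(
        self, aggregate_state_dict: Dict[str, Any]
    ) -> Dict[str, Any]:
        # Explicit check that the checkpoint holds state captured by this
        # dataloader's aggregate_state_dict.
        if "combined_step" not in aggregate_state_dict:
            raise RuntimeError(
                "Checkpoint does not contain state saved by "
                "CustomRestartableDataLoader"
            )
        # Every worker will load the combined step count of the previous run.
        return {"combined_step": aggregate_state_dict["combined_step"]}
```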
In the example above, our implementation has an explicit check to ensure that we’re loading state captured by this dataloader. Upon restart, we assume that each worker cares about the combined step count of all workers in the previous run at the checkpoint we’re loading from; thus, the deaggregation method constructs and returns a single state holding the combined step info.
This method will be particularly useful when the number of workers per box changes between subsequent runs; use this to specify which state dict should be loaded by each worker upon restart.
load_state_dict
This method accepts a disaggregated state dict as an argument, as defined in your implementation of deaggregate_state_dict.
Use this method to specify how the worker should load its state from the provided, disaggregated state dict, e.g.
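A sketch of loading the per-worker state, continuing the same class:

```python
from typing import Any, Dict

import torch


class CustomRestartableDataLoader(torch.utils.data.DataLoader):
    # ... the other protocol methods as defined above ...

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # Check that the state matches what deaggregate_state_dict produced.
        if "combined_step" not in state_dict:
            raise RuntimeError("Unexpected dataloader state format on restart")
        # For this example, each worker simply reports the combined step count;
        # a real dataloader could use it to fast-forward its sampler/iterator.
        print(f"Restarting from combined step {state_dict['combined_step']}")
```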
Again, we have an explicit check to ensure that the disaggregated state dict being used by each worker to load its state upon restart is the same as that specified by our dataloader’s implementation of deaggregate_state_dict. For this example, each worker simply prints the combined step count from the previous run, but you can imagine using this step count to set other properties on your dataloader that enable it to restart deterministically.
Putting it All Together
Combining all of the above, the following sets up our custom restartable dataloader whose state can be captured via checkpointing:
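A consolidated sketch, under the same assumptions as the subsections above (illustrative DataLoaderCheckpoint attribute names, a two-worker run, an illustrative input_fn and checkpoint filename, and the cstorch import adjusted to your release):

```python
from typing import Any, Dict, List

import torch
import cerebras.pytorch as cstorch  # adjust the import to match your release


class CustomRestartableDataLoader(torch.utils.data.DataLoader):
    """A dataloader conforming to the RestartableDataLoader protocol."""

    def state_dict(self) -> Dict[str, Any]:
        worker_state = cstorch.distributed.get_worker_state()
        return {
            "worker_step": worker_state.worker_step,
            "worker_id": worker_state.global_worker_id,
        }

    def aggregate_state_dict(
        self, worker_states: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        # Assumes a total of two workers, ordered by global worker id.
        assert len(worker_states) == 2
        return {
            "worker_0_step": worker_states[0]["worker_step"],
            "worker_1_id": worker_states[1]["worker_id"],
            "combined_step": (
                worker_states[0]["worker_step"]
                + worker_states[1]["worker_step"]
            ),
        }

    def deaggregate_state_dict(
        self, aggregate_state_dict: Dict[str, Any]
    ) -> Dict[str, Any]:
        if "combined_step" not in aggregate_state_dict:
            raise RuntimeError("Unexpected aggregated dataloader state")
        return {"combined_step": aggregate_state_dict["combined_step"]}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        if "combined_step" not in state_dict:
            raise RuntimeError("Unexpected dataloader state format on restart")
        print(f"Restarting from combined step {state_dict['combined_step']}")


def input_fn():
    # Illustrative input function returning the custom restartable dataloader.
    dataset = torch.utils.data.TensorDataset(torch.randn(1000, 16))
    return CustomRestartableDataLoader(dataset, batch_size=8)


# Wrap with the Cerebras PyTorch dataloader wrapper so its state can be
# captured at checkpoint steps and restored on restart.
dataloader = cstorch.utils.data.DataLoader(input_fn)


@cstorch.checkpoint_closure
def save_checkpoint(step):
    cstorch.save({"dataloader": dataloader.state_dict()}, f"checkpoint_{step}.mdl")
```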
It is not necessary for your dataloader to be of type torch.utils.data.DataLoader to enable the saving and loading of its state; in fact, any iterable that returns a structure comprising torch tensors can be made restartable as long as it implements the four signature methods conforming to the Cerebras RestartableDataLoader protocol class.