To save the dataloader's state, call `state_dict` on the Cerebras PyTorch dataloader wrapper, which must be initialized at the beginning of a custom training loop.
The call must be made inside a function wrapped with the `checkpoint_closure` decorator.
To restore the state on a restart, call the `load_state_dict` method on the Cerebras PyTorch dataloader wrapper to load your dataloader's state.
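A minimal sketch of this save/load round trip is shown below. To keep it self-contained, it uses a hypothetical `TinyLoader` stand-in rather than the real Cerebras dataloader wrapper, and plain `pickle` in place of the framework's checkpoint machinery; in an actual run, the save would happen inside a `checkpoint_closure`-decorated function alongside model and optimizer state.

```python
import os
import pickle
import tempfile


class TinyLoader:
    """Hypothetical stand-in for the Cerebras PyTorch dataloader wrapper."""

    def __init__(self):
        self.samples_seen = 0  # illustrative per-loader progress marker

    def state_dict(self):
        return {"samples_seen": self.samples_seen}

    def load_state_dict(self, state):
        self.samples_seen = state["samples_seen"]


loader = TinyLoader()
loader.samples_seen = 128  # pretend we have already streamed 128 samples

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "ckpt.pkl")

    # Save: in a real loop this happens inside a checkpoint_closure-decorated
    # function, as part of the regular checkpoint.
    with open(path, "wb") as f:
        pickle.dump({"dataloader": loader.state_dict()}, f)

    # Load on restart: restore the dataloader state before resuming training.
    with open(path, "rb") as f:
        ckpt = pickle.load(f)
    fresh = TinyLoader()
    fresh.load_state_dict(ckpt["dataloader"])
```

After `load_state_dict`, the fresh loader resumes from the saved position instead of the beginning of the dataset.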
To make your dataloader restartable, it must implement the `RestartableDataLoader` protocol. By implementing the four methods `state_dict`, `aggregate_state_dict`, `deaggregate_state_dict`, and `load_state_dict` with the appropriate method signatures, your dataloader is guaranteed to be restartable; that is, you can save your dataloader's state in a checkpoint and load it via the mechanism described above.
Recall that in a distributed setting, each input worker per CSX creates its own instance of the dataloader for parallelism. Thus, these four methods determine how your dataloader's state is saved and loaded to enable deterministic restarts in such settings.
To illustrate the usage of this protocol, we define our `CustomRestartableDataLoader` class below. The following subsections describe each method signature, both generally and in the context of our custom class.
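As an orientation before the per-method subsections, here is a compact, self-contained sketch of a class conforming to the four-method protocol. It is plain Python (no Cerebras imports), and the tracked state, a `samples_seen` counter, is a hypothetical example rather than the state the real class saves.

```python
from typing import Any, Dict, Iterator, List


class CustomRestartableDataLoader:
    """Sketch of a dataloader implementing the RestartableDataLoader
    protocol: state_dict, aggregate_state_dict, deaggregate_state_dict,
    and load_state_dict."""

    def __init__(self, dataset: List[Any]):
        self.dataset = dataset
        self.samples_seen = 0  # illustrative per-worker progress marker

    def __iter__(self) -> Iterator[Any]:
        # Resume iteration from wherever the loaded state left off.
        for idx in range(self.samples_seen, len(self.dataset)):
            self.samples_seen = idx + 1
            yield self.dataset[idx]

    def state_dict(self) -> Dict[str, Any]:
        # Called per worker: return this worker's state to checkpoint.
        return {"samples_seen": self.samples_seen}

    def aggregate_state_dict(
        self, worker_states: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        # Combine all workers' states into the single dict that is saved.
        return {"all_workers": worker_states}

    def deaggregate_state_dict(
        self, aggregated: Dict[str, Any]
    ) -> Dict[str, Any]:
        # Extract one worker's slice of the aggregated state on restart.
        # A real implementation would select by worker id; for this
        # single-worker sketch we simply take the first entry.
        return aggregated["all_workers"][0]

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        # Restore per-worker state so iteration resumes deterministically.
        self.samples_seen = state["samples_seen"]
```

The aggregate/deaggregate pair is what lets one checkpoint hold the state of many parallel workers: each worker contributes via `state_dict`, the framework merges the pieces with `aggregate_state_dict`, and on restart each worker recovers its own slice with `deaggregate_state_dict` before `load_state_dict` is applied.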
Each worker's internal state is captured in the `DataLoaderCheckpoint` dataclass. Please refer to the linked docs on this class for detailed information on each attribute. Essentially, in your definition of `state_dict` you may choose to save any of the aforementioned internal state info per worker. We expose an API method, `get_worker_state`, that you may use in your implementation of `state_dict` to fetch the worker's internal state info. Note that `get_worker_state` is well-defined only inside your implementation of `state_dict`; calling this method anywhere else will result in a `RuntimeError` exception.
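To make that restriction concrete, here is one way such a context-restricted API can be implemented: a hypothetical sketch of the mechanism (not Cerebras's actual code), in which the framework sets a flag only while it is collecting `state_dict`, and `get_worker_state` refuses to run otherwise.

```python
from contextlib import contextmanager

_IN_STATE_DICT = False  # set only while state_dict is being collected


@contextmanager
def _state_dict_context():
    """Mark the region in which get_worker_state is valid."""
    global _IN_STATE_DICT
    _IN_STATE_DICT = True
    try:
        yield
    finally:
        _IN_STATE_DICT = False


def get_worker_state():
    """Return this worker's state; valid only during state_dict collection."""
    if not _IN_STATE_DICT:
        raise RuntimeError(
            "get_worker_state may only be called inside state_dict"
        )
    # Illustrative payload; the real API returns a DataLoaderCheckpoint.
    return {"worker_id": 0, "samples_streamed": 42}


def collect_state_dict(user_state_dict):
    # The framework wraps the user's state_dict call in the context,
    # so get_worker_state works there and nowhere else.
    with _state_dict_context():
        return user_state_dict()
```

Calling `get_worker_state` through `collect_state_dict` succeeds, while a bare call raises `RuntimeError`, mirroring the behavior described above.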
For reference, see the `RestartableDataLoader` protocol class.