# Trainer API

class cerebras.modelzoo.Trainer(device=None, backend=None, model\_dir=`<object object>`, model=`<object object>`, optimizer=None, schedulers=None, precision=None, sparsity=None, loop=None, checkpoint=None, logging=None, callbacks=None, loggers=None, seed=None)

The Trainer class is the main entry point for training models in ModelZoo.

Parameters

* **device** (*Optional*\[*str*]) – The device to train the model on. It must be one of “CSX”, “CPU”, or “GPU”.

* **backend** (*Optional*\[*Backend*]) – The backend used to train the model. This argument is mutually exclusive with device.

* **model\_dir** (*str*) – The directory where the model artifacts are saved.

* **model** (*Union*\[*Callable*\[\[], [*torch.nn.Module*](https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module "(in PyTorch v2.4)")], [*torch.nn.Module*](https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module "(in PyTorch v2.4)")]) – The model to train. It must be one of the following:

  * If a callable is passed, it is assumed to be a function that takes no arguments and returns a torch.nn.Module.

  * If a torch.nn.Module is passed, it is used as is.

* **optimizer** (*Union*\[[*Optimizer*](../../api/cerebras_pytorch/optim.html#cerebras.pytorch.optim.Optimizer "cerebras.pytorch.optim.Optimizer"), *Callable*\[\[[*torch.nn.Module*](https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module "(in PyTorch v2.4)")], [*Optimizer*](../../api/cerebras_pytorch/optim.html#cerebras.pytorch.optim.Optimizer "cerebras.pytorch.optim.Optimizer")], *None*]) – The optimizer used to optimize the model. It must be one of the following:

  * If an [`Optimizer`](../../api/cerebras_pytorch/optim.html#cerebras.pytorch.optim.Optimizer "cerebras.pytorch.optim.Optimizer") is passed, it is used as is.

  * If a callable is passed, it is assumed to be a function that takes in a torch.nn.Module and returns an [`Optimizer`](../../api/cerebras_pytorch/optim.html#cerebras.pytorch.optim.Optimizer "cerebras.pytorch.optim.Optimizer").

  * If not passed, the Trainer assumes that only validation will be run.

* **schedulers** (*SchedulersInput*) – The set of optimizer schedulers to be used. Common schedulers include LR schedulers. It must be either None or a list in which each item is one of the following:

  * A cstorch.optim.scheduler.Scheduler, which is used as is.

  * A callable, which is assumed to be a function that takes in an [`Optimizer`](../../api/cerebras_pytorch/optim.html#cerebras.pytorch.optim.Optimizer "cerebras.pytorch.optim.Optimizer") and returns a cstorch.optim.scheduler.Scheduler.

  If None is passed, there is no optimizer param group scheduling.

* **precision** (*Optional*\[[*Precision*](generated/cerebras.modelzoo.trainer.callbacks.html#cerebras.modelzoo.trainer.callbacks.Precision "cerebras.modelzoo.trainer.callbacks.Precision")]) – The Precision callback used during training.

* **sparsity** (*Optional*\[[*SparsityAlgorithm*](../../api/cerebras_pytorch/sparse.html#cerebras.pytorch.sparse.SparsityAlgorithm "cerebras.pytorch.sparse.SparsityAlgorithm")]) – The sparsity algorithm used to sparsify weights during training/validation. It must be one of the following:

  * If a callable is passed, it is assumed to be a function that takes no arguments and returns a [`SparsityAlgorithm`](../../api/cerebras_pytorch/sparse.html#cerebras.pytorch.sparse.SparsityAlgorithm "cerebras.pytorch.sparse.SparsityAlgorithm").

  * If a [`SparsityAlgorithm`](../../api/cerebras_pytorch/sparse.html#cerebras.pytorch.sparse.SparsityAlgorithm "cerebras.pytorch.sparse.SparsityAlgorithm") is passed, it is used as is.

* **loop** (*Optional*\[[*LoopCallback*](generated/cerebras.modelzoo.trainer.callbacks.html#cerebras.modelzoo.trainer.callbacks.LoopCallback "cerebras.modelzoo.trainer.callbacks.LoopCallback")]) – The loop callback to use for training. It must be an instance of LoopCallback. If not provided, the default loop is TrainingLoop(num\_epochs=1).

* **checkpoint** (*Optional*\[[*Checkpoint*](generated/cerebras.modelzoo.trainer.callbacks.html#cerebras.modelzoo.trainer.callbacks.Checkpoint "cerebras.modelzoo.trainer.callbacks.Checkpoint")]) – The checkpoint callback to use for saving/loading checkpoints. It must be an instance of Checkpoint. If not provided, then no checkpoints are saved.

* **logging** (*Optional*\[[*Logging*](generated/cerebras.modelzoo.trainer.callbacks.html#cerebras.modelzoo.trainer.callbacks.Logging "cerebras.modelzoo.trainer.callbacks.Logging")]) – The logging callback used to set up Python logging. This callback also controls when logs are emitted. If not provided, the default logging settings `Logging(log_steps=1, log_level="INFO")` are used.

* **callbacks** (*Optional*\[*List*\[[*Callback*](generated/cerebras.modelzoo.trainer.callbacks.html#cerebras.modelzoo.trainer.callbacks.Callback "cerebras.modelzoo.trainer.callbacks.Callback")]]) – A list of callbacks to be used by the trainer. The order in which the callbacks are provided is important, as it determines the order in which the callbacks' hooks are executed.

* **loggers** (*Optional*\[*List*\[[*Logger*](generated/cerebras.modelzoo.trainer.loggers.html#cerebras.modelzoo.trainer.loggers.Logger "cerebras.modelzoo.trainer.loggers.Logger")]]) – A list of loggers to use for logging.

* **seed** (*Optional*\[*int*]) – Initial seed for the torch random number generator.
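
Several constructor arguments (model, optimizer, sparsity, and the items of schedulers) accept either a ready-made object or a callable that builds one. The following is a minimal sketch of that resolution pattern in plain Python. `ToyModule`, `ToyOptimizer`, `resolve_model`, and `resolve_optimizer` are hypothetical stand-ins introduced for illustration; they are not part of `cerebras.modelzoo`, and the real Trainer may resolve these arguments differently.

```python
from typing import Callable, Optional, Union


class ToyModule:
    """Hypothetical stand-in for torch.nn.Module (instances are callable)."""

    def __call__(self, x):
        return x


class ToyOptimizer:
    """Hypothetical stand-in for an optimizer bound to a model."""

    def __init__(self, model: ToyModule):
        self.model = model


def resolve_model(model: Union[ToyModule, Callable[[], ToyModule]]) -> ToyModule:
    # Check the instance case first: module instances are themselves
    # callable, so testing callable() first would wrongly invoke the model.
    if isinstance(model, ToyModule):
        return model
    return model()  # zero-argument factory


def resolve_optimizer(optimizer, model: ToyModule) -> Optional[ToyOptimizer]:
    if optimizer is None:
        return None  # validation-only run: nothing to optimize
    if isinstance(optimizer, ToyOptimizer):
        return optimizer
    return optimizer(model)  # factory that receives the resolved model


model = resolve_model(lambda: ToyModule())
optimizer = resolve_optimizer(ToyOptimizer, model)
```

Passing a factory rather than an instance lets construction be deferred until the device or backend has been configured, which is one plausible reason for the callable form.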

*property* all\_callbacks: Generator\[[cerebras.modelzoo.trainer.callbacks.callback.Callback](generated/cerebras.modelzoo.trainer.callbacks.html#cerebras.modelzoo.trainer.callbacks.Callback "cerebras.modelzoo.trainer.callbacks.callback.Callback"), None, None]

Get all callback objects available to the trainer.

get\_callbacks(*callback\_type*)[\[source\]](../_modules/cerebras/modelzoo/trainer/trainer.html#Trainer.get_callbacks)

Get all callbacks of the given type.

get\_callback(*callback\_type*)[\[source\]](../_modules/cerebras/modelzoo/trainer/trainer.html#Trainer.get_callback)

Get the first callback of the given type.
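
As a rough illustration of type-based lookup, the sketch below filters a callback list with `isinstance`. The classes and free functions are hypothetical stand-ins, not the library's implementation. Returning `None` when nothing matches is an assumption; the behavior of the real `get_callback` in that case is not documented here.

```python
from typing import Iterator, List, Optional, Type, TypeVar


class Callback:
    """Hypothetical base class for all callbacks."""


class LoggingCallback(Callback):
    pass


class CheckpointCallback(Callback):
    pass


T = TypeVar("T", bound=Callback)


def get_callbacks(callbacks: List[Callback], callback_type: Type[T]) -> Iterator[T]:
    # Yield matches in the order the callbacks were registered.
    for callback in callbacks:
        if isinstance(callback, callback_type):
            yield callback


def get_callback(callbacks: List[Callback], callback_type: Type[T]) -> Optional[T]:
    # First match, or None if there is no match (assumed behavior).
    return next(get_callbacks(callbacks, callback_type), None)


callbacks = [LoggingCallback(), CheckpointCallback(), LoggingCallback()]
first_logging = get_callback(callbacks, LoggingCallback)
```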

*property* validation\_callbacks: List\[[cerebras.modelzoo.trainer.callbacks.callback.ValidationCallback](generated/cerebras.modelzoo.trainer.callbacks.html#cerebras.modelzoo.trainer.callbacks.ValidationCallback "cerebras.modelzoo.trainer.callbacks.callback.ValidationCallback")]

Returns all validation callbacks in the Trainer’s callback list.

call(*hook\_name*, *\*args*, *\*\*kwargs*)[\[source\]](../../../_modules/cerebras/modelzoo/trainer/trainer.html#Trainer.call)

Call the hook with name hook\_name for all callbacks in the Trainer’s callback list as well as the callbacks in the global registry.

Each callback’s hook method receives the trainer object itself, followed by any args and kwargs that are passed into this method.

Parameters

* **hook\_name** (*str*) – The name of the hook to call. It must be the name of a method in the Callback class.

* **args** – Other positional arguments to forward along to the called hook.

* **kwargs** – Other keyword arguments to forward along to the called hook.
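
The dispatch described above can be sketched as follows. This is a toy model, not the Trainer's code: `ToyTrainer`, `Recorder`, and the hook name `on_train_start` are assumptions chosen for illustration, and the global callback registry is not modeled.

```python
class Callback:
    """Hypothetical base class; every hook is a no-op by default."""

    def on_train_start(self, trainer, *args, **kwargs):
        pass


class Recorder(Callback):
    """Records each hook invocation so the call order can be inspected."""

    def __init__(self, name, log):
        self.name = name
        self.log = log

    def on_train_start(self, trainer, *args, **kwargs):
        self.log.append((self.name, args, kwargs))


class ToyTrainer:
    def __init__(self, callbacks):
        self.callbacks = list(callbacks)

    def call(self, hook_name, *args, **kwargs):
        # Look the hook up by name on each callback, in registration order,
        # and pass the trainer itself as the first argument.
        for callback in self.callbacks:
            getattr(callback, hook_name)(self, *args, **kwargs)


events = []
trainer = ToyTrainer([Recorder("first", events), Recorder("second", events)])
trainer.call("on_train_start", 1, flag=True)
```

After the call, `events` holds one entry per callback in the order the callbacks were supplied, which is why the ordering of the `callbacks` constructor argument matters.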

*property* precision: Optional\[[cerebras.modelzoo.trainer.callbacks.precision.Precision](generated/cerebras.modelzoo.trainer.callbacks.html#cerebras.modelzoo.trainer.callbacks.Precision "cerebras.modelzoo.trainer.callbacks.precision.Precision")]

Returns the precision callback instance if it exists.

*property* grad\_accum: [cerebras.modelzoo.trainer.callbacks.grad\_accum.GradientAccumulationCallback](generated/cerebras.modelzoo.trainer.callbacks.html#cerebras.modelzoo.trainer.callbacks.GradientAccumulationCallback "cerebras.modelzoo.trainer.callbacks.grad_accum.GradientAccumulationCallback")

Returns the gradient accumulation callback instance.

*property* should\_run\_optimizer\_step: bool

Returns True if we should run the optimizer step.

The gradient accumulation callback may set this to False while gradients are being accumulated and the configured number of accumulation steps has not yet been reached. Note that this only applies to CPU/GPU runs.
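
To make the accumulation rule concrete, here is a small sketch that computes, for each batch, whether an optimizer step would run. The function name and the modulo rule are assumptions for illustration; the real callback may track its state differently.

```python
from typing import List


def optimizer_step_schedule(num_batches: int, accum_steps: int) -> List[bool]:
    """Return, per batch, whether the optimizer step would run.

    Gradients are accumulated over `accum_steps` consecutive batches and the
    optimizer steps only on the last batch of each group.
    """
    return [(i + 1) % accum_steps == 0 for i in range(num_batches)]


schedule = optimizer_step_schedule(num_batches=6, accum_steps=3)
```

With these inputs the optimizer would step on the third and sixth batches only. With `accum_steps=1` every batch triggers a step.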

*property* loop: [cerebras.modelzoo.trainer.callbacks.loop.LoopCallback](generated/cerebras.modelzoo.trainer.callbacks.html#cerebras.modelzoo.trainer.callbacks.LoopCallback "cerebras.modelzoo.trainer.callbacks.loop.LoopCallback")

Returns the default loop settings.

*property* checkpoint: [cerebras.modelzoo.trainer.callbacks.checkpoint.Checkpoint](generated/cerebras.modelzoo.trainer.callbacks.html#cerebras.modelzoo.trainer.callbacks.Checkpoint "cerebras.modelzoo.trainer.callbacks.checkpoint.Checkpoint")

Returns the checkpoint callback.

*property* logging: [cerebras.modelzoo.trainer.callbacks.Logging](generated/cerebras.modelzoo.trainer.callbacks.html#cerebras.modelzoo.trainer.callbacks.Logging "cerebras.modelzoo.trainer.callbacks.Logging")

Returns the logging callback.

*property* logger: logging.Logger

Returns the Trainer’s Python logger object.

*property* is\_log\_step: bool

Returns True if the current step is a log step.

*property* is\_first\_iteration: bool

Returns True if the executor is on its first iteration.

*property* is\_final\_iteration: bool

Returns True if the executor is on its final iteration.

*property* is\_tracing: bool

Returns True if we are currently tracing the model.

*final* log\_metrics\_in\_step\_closure(*\*\*kwargs*)[\[source\]](../../../_modules/cerebras/modelzoo/trainer/trainer.html#Trainer.log_metrics_in_step_closure)

Log the given kwargs inside a step closure.

*final* log\_metrics(*\*\*kwargs*)[\[source\]](../../../_modules/cerebras/modelzoo/trainer/trainer.html#Trainer.log_metrics)

Log the given kwargs to all loggers.

Example usage:

`trainer.log_metrics(loss=loss.item())`

Parameters

**kwargs** – The key-value pairs to log.
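
The fan-out to all loggers can be pictured with the toy sketch below. `ToyLogger`, `ToyTrainer`, the `global_step` attribute, and the `log_metrics(metrics, step)` logger signature are assumptions for illustration rather than the documented Logger interface.

```python
class ToyLogger:
    """Hypothetical logger that stores what it receives."""

    def __init__(self):
        self.records = []

    def log_metrics(self, metrics, step):
        self.records.append((step, dict(metrics)))


class ToyTrainer:
    def __init__(self, loggers):
        self.loggers = list(loggers)
        self.global_step = 0  # assumed step counter

    def log_metrics(self, **kwargs):
        # Forward the same key-value pairs to every registered logger.
        for logger in self.loggers:
            logger.log_metrics(kwargs, self.global_step)


console, file_like = ToyLogger(), ToyLogger()
trainer = ToyTrainer([console, file_like])
trainer.log_metrics(loss=0.25, lr=1e-3)
```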

*final* name\_scope(*name*)[\[source\]](../../../_modules/cerebras/modelzoo/trainer/trainer.html#Trainer.name_scope)

Append name to the trainer’s name scope stack whilst inside the context.

Parameters

**name** (*str*) – The name to append to the name scope stack.

*property* name\_scope\_path: str

Returns the current name scope path.

This is the name scope stack joined by ‘/’.
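
A name scope stack of this kind can be sketched with `contextlib.contextmanager`. `ToyTrainer` and its private `_name_scope_stack` attribute are hypothetical; only the push-on-enter, pop-on-exit, join-with-slash behavior comes from the description above.

```python
from contextlib import contextmanager


class ToyTrainer:
    def __init__(self):
        self._name_scope_stack = []

    @contextmanager
    def name_scope(self, name):
        # Push on entry and always pop on exit, even if the body raises.
        self._name_scope_stack.append(name)
        try:
            yield
        finally:
            self._name_scope_stack.pop()

    @property
    def name_scope_path(self):
        return "/".join(self._name_scope_stack)


trainer = ToyTrainer()
paths = []
with trainer.name_scope("validate"):
    with trainer.name_scope("dataloader_0"):
        paths.append(trainer.name_scope_path)  # innermost scope
    paths.append(trainer.name_scope_path)      # back in the outer scope
paths.append(trainer.name_scope_path)          # outside all scopes
```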

*final* get\_val\_dataloader\_scope(*val\_dataloader*)[\[source\]](../../../_modules/cerebras/modelzoo/trainer/trainer.html#Trainer.get_val_dataloader_scope)

Get the name scope for the given val dataloader.

*final* training\_step(*batch*)[\[source\]](../../../_modules/cerebras/modelzoo/trainer/trainer.html#Trainer.training_step)

Run a single training step on the given batch.

Note that if retrace is off, the contents of this method only run on the first iteration, so any inputs to this method must either be non-changing or be torch tensors.

Parameters

* **batch** – The batch of data to train on.

* **batch\_idx** – The index of the batch in the dataloader.

Returns

A dictionary containing the loss and any other outputs.

Return type

*Dict*\[str, *Any*]

*final* forward(*batch*)[\[source\]](../../../_modules/cerebras/modelzoo/trainer/trainer.html#Trainer.forward)

Run the forward pass on the given batch.

Parameters

**batch** – The batch of data to run the forward pass on.

Returns

A dictionary containing the loss and any other outputs.

Return type

*Dict*\[str, *Any*]

*final* backward(*outputs*)[\[source\]](../../../_modules/cerebras/modelzoo/trainer/trainer.html#Trainer.backward)

Run the backward pass on the given loss.

Parameters

**outputs** (*dict*) – The outputs of the model. Expect key ‘loss’ to be present.

*final* optimizer\_step()[\[source\]](../../../_modules/cerebras/modelzoo/trainer/trainer.html#Trainer.optimizer_step)

Run the optimizer step.

*final* optimizer\_zero\_grad()[\[source\]](../../../_modules/cerebras/modelzoo/trainer/trainer.html#Trainer.optimizer_zero_grad)

Zero the gradients of the optimizer.

*final* schedulers\_step()[\[source\]](../../../_modules/cerebras/modelzoo/trainer/trainer.html#Trainer.schedulers_step)

Step all the schedulers.
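
The step-level methods above (forward, backward, optimizer\_step, optimizer\_zero\_grad, schedulers\_step) are the usual pieces of one training iteration. The sketch below wires them together for a single scalar weight in plain Python. The ordering inside `training_step`, the hand-written gradient, and the multiplicative learning-rate decay are all assumptions made for illustration; this page does not specify how the real Trainer sequences these calls.

```python
class ToyStep:
    """One-weight model trained on squared error, with manual gradients."""

    def __init__(self, w=0.0, lr=0.1, decay=0.5):
        self.w = w          # single trainable weight
        self.grad = 0.0     # accumulated gradient for w
        self.lr = lr
        self.decay = decay  # hypothetical multiplicative LR schedule

    def forward(self, batch):
        x, y = batch
        pred = self.w * x
        return {"loss": (pred - y) ** 2, "pred": pred, "x": x, "y": y}

    def backward(self, outputs):
        # d/dw (w*x - y)^2 = 2 * (w*x - y) * x
        self.grad += 2 * (outputs["pred"] - outputs["y"]) * outputs["x"]

    def optimizer_step(self):
        self.w -= self.lr * self.grad

    def optimizer_zero_grad(self):
        self.grad = 0.0

    def schedulers_step(self):
        self.lr *= self.decay

    def training_step(self, batch):
        outputs = self.forward(batch)
        self.backward(outputs)
        self.optimizer_step()
        self.optimizer_zero_grad()
        self.schedulers_step()
        return outputs  # dictionary containing the loss


step = ToyStep()
outputs = step.training_step((1.0, 2.0))
```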

on\_exception(*hook*)[\[source\]](../../../_modules/cerebras/modelzoo/trainer/trainer.html#Trainer.on_exception)

Context manager to handle exceptions in the given hook.

Parameters

**hook** – The hook to handle exceptions for.
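
A context manager of this shape might look like the sketch below. What the real Trainer does with a caught exception is not described on this page, so recording the failure and re-raising is purely an assumption, as are `ToyTrainer` and its `failed_hooks` attribute.

```python
from contextlib import contextmanager


class ToyTrainer:
    def __init__(self):
        self.failed_hooks = []

    @contextmanager
    def on_exception(self, hook):
        try:
            yield
        except Exception as exc:
            # Note which hook failed and with what error, then re-raise so
            # the caller still sees the original exception.
            self.failed_hooks.append((hook, type(exc).__name__))
            raise


trainer = ToyTrainer()
try:
    with trainer.on_exception("fit"):
        raise ValueError("bad batch")
except ValueError:
    pass  # the exception propagated after being recorded
```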

*final* fit(*train\_dataloader*, *val\_dataloader=None*, *ckpt\_path=`<object object>`*)[\[source\]](.././_modules/cerebras/modelzoo/trainer/trainer.html#Trainer.fit)

Complete a full training run on the given train and validation dataloaders.

Parameters

* **train\_dataloader** (*cerebras.appliance.log.named\_class\_logger*) – The training dataloader.

* **val\_dataloader** (*Optional*\[*Union*\[*cerebras.appliance.log.named\_class\_logger*, *List*\[*cerebras.appliance.log.named\_class\_logger*]]]) – The validation dataloader. If provided, validation is run every eval\_frequency steps as defined in the loop callback. If not provided, only training is run. If a list of dataloaders is provided, then each dataloader is validated in sequence.

* **ckpt\_path** (*Optional*\[*str*]) – The path to the checkpoint to load before starting training. If not provided and autoload\_last\_checkpoint is True, then the latest checkpoint is loaded.
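
The interleaving of training and validation described for val\_dataloader can be sketched as an event schedule. `fit_schedule` is a hypothetical helper, and running validation exactly when the step count is a multiple of eval\_frequency is an assumed reading of "every eval\_frequency steps".

```python
from typing import List, Tuple


def fit_schedule(num_steps: int, eval_frequency: int, num_val_dataloaders: int) -> List[Tuple]:
    """List the train/validate events a run would produce, in order."""
    events = []
    for step in range(1, num_steps + 1):
        events.append(("train", step))
        if num_val_dataloaders and step % eval_frequency == 0:
            # Each validation dataloader is run in sequence.
            for idx in range(num_val_dataloaders):
                events.append(("validate", step, idx))
    return events


events = fit_schedule(num_steps=4, eval_frequency=2, num_val_dataloaders=2)
```

With no validation dataloaders the schedule contains only training events.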

validation\_step

*final* validate(*val\_dataloader=None*, *ckpt\_path=`<object object>`*, *loop=None*)[\[source\]](../../../_modules/cerebras/modelzoo/trainer/trainer.html#Trainer.validate)

Complete a full validation run on the validation dataloader.

Parameters

* **val\_dataloader** (*Optional*\[*cerebras.appliance.log.named\_class\_logger*]) – The validation dataloader. If a list of dataloaders is provided, then each dataloader is validated in sequence.

* **ckpt\_path** (*Optional*\[*str*]) – The path to the checkpoint to load before starting validation. If not provided and autoload\_last\_checkpoint is True, then the latest checkpoint is loaded.

* **loop** (*Optional*\[[*cerebras.modelzoo.trainer.callbacks.loop.ValidationLoop*](generated/cerebras.modelzoo.trainer.callbacks.html#cerebras.modelzoo.trainer.callbacks.ValidationLoop "cerebras.modelzoo.trainer.callbacks.loop.ValidationLoop")]) – The loop callback to use for validation. If not provided, the default loop is used. If provided, it must be an instance of ValidationLoop. Note that this should only be provided if the loop callback passed to the constructor is not sufficient.

*final* validate\_all(*val\_dataloaders=None*, *ckpt\_paths=`<object object>`*, *loop=None*)[\[source\]](../../../_modules/cerebras/modelzoo/trainer/trainer.html#Trainer.validate_all)

Runs all upstream and downstream validation permutations.

    for ckpt_path in ckpt_paths:
        for val_dataloader in val_dataloaders:
            trainer.validate(val_dataloader, ckpt_path)

        # run downstream validation
        run_validation(...)

Parameters

* **val\_dataloaders** (*Optional*\[*Union*\[*cerebras.appliance.log.named\_class\_logger*, *List*\[*cerebras.appliance.log.named\_class\_logger*]]]) – A list of validation dataloaders to run validation on.

* **ckpt\_paths** (*Optional*\[*Union*\[*List*\[*str*], *str*]]) – A list of checkpoint paths to run validation on. Each checkpoint path must be a path to a checkpoint file, or a glob pattern.

* **loop** (*Optional*\[[*cerebras.modelzoo.trainer.callbacks.loop.ValidationLoop*](generated/cerebras.modelzoo.trainer.callbacks.html#cerebras.modelzoo.trainer.callbacks.ValidationLoop "cerebras.modelzoo.trainer.callbacks.loop.ValidationLoop")]) – The validation loop to use for validation. If not provided, then the default loop is used.
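
Because ckpt\_paths may contain glob patterns, one way to picture the upstream part of validate\_all is to expand the patterns first and then loop over every checkpoint and dataloader pair. The sketch below uses only the standard library. `expand_ckpt_paths`, the free function `validate_all`, the `validate` callback argument, and the `checkpoint_*.mdl` file names are hypothetical, sorting the matches is an assumption, and downstream validation is not modeled.

```python
import glob
import os
import tempfile


def expand_ckpt_paths(ckpt_paths):
    """Expand a path, a glob pattern, or a list of either into file paths."""
    if isinstance(ckpt_paths, str):
        ckpt_paths = [ckpt_paths]
    expanded = []
    for pattern in ckpt_paths:
        expanded.extend(sorted(glob.glob(pattern)))
    return expanded


def validate_all(val_dataloaders, ckpt_paths, validate):
    # Every checkpoint is validated against every dataloader.
    for ckpt_path in expand_ckpt_paths(ckpt_paths):
        for val_dataloader in val_dataloaders:
            validate(val_dataloader, ckpt_path)


runs = []
with tempfile.TemporaryDirectory() as tmp:
    for step in (100, 200):
        # Create empty placeholder files standing in for checkpoints.
        open(os.path.join(tmp, f"checkpoint_{step}.mdl"), "w").close()
    validate_all(
        ["val_a", "val_b"],
        os.path.join(tmp, "checkpoint_*.mdl"),
        lambda dl, ckpt: runs.append((os.path.basename(ckpt), dl)),
    )
```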

*final* save\_checkpoint()[\[source\]](../../../_modules/cerebras/modelzoo/trainer/trainer.html#Trainer.save_checkpoint)

Save a checkpoint at the current global step.

The checkpoint state dict is constructed by various callbacks that implement the on\_save\_checkpoint method.

*final* load\_checkpoint(*ckpt\_path=None*)[\[source\]](../../../_modules/cerebras/modelzoo/trainer/trainer.html#Trainer.load_checkpoint)

Load a checkpoint from the given path.

The checkpoint state dict is loaded and processed by various callbacks that implement the on\_load\_checkpoint method.

Parameters

**ckpt\_path** (*Optional*\[*str*]) – The path to the checkpoint to load. If not provided and autoload\_last\_checkpoint is True, then the latest checkpoint is loaded.
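
The idea that callbacks build and consume the checkpoint state dict can be sketched as below. The classes are toy stand-ins, the `(trainer, state_dict)` hook signature is an assumption, and the state dict is kept in memory rather than written to a file.

```python
class ModelState:
    """Hypothetical callback that owns the model weights."""

    def __init__(self):
        self.weights = [1.0, 2.0]

    def on_save_checkpoint(self, trainer, state_dict):
        state_dict["model"] = list(self.weights)

    def on_load_checkpoint(self, trainer, state_dict):
        self.weights = list(state_dict["model"])


class StepState:
    """Hypothetical callback that owns the global step counter."""

    def __init__(self):
        self.global_step = 0

    def on_save_checkpoint(self, trainer, state_dict):
        state_dict["global_step"] = self.global_step

    def on_load_checkpoint(self, trainer, state_dict):
        self.global_step = state_dict["global_step"]


class ToyTrainer:
    def __init__(self, callbacks):
        self.callbacks = list(callbacks)

    def save_checkpoint(self):
        # Each callback contributes its own entries to the state dict.
        state_dict = {}
        for callback in self.callbacks:
            callback.on_save_checkpoint(self, state_dict)
        return state_dict

    def load_checkpoint(self, state_dict):
        # Each callback restores the entries it is responsible for.
        for callback in self.callbacks:
            callback.on_load_checkpoint(self, state_dict)


model_state, step_state = ModelState(), StepState()
trainer = ToyTrainer([model_state, step_state])
step_state.global_step = 10
saved = trainer.save_checkpoint()

# Mutate the live state, then restore it from the saved state dict.
model_state.weights = [0.0, 0.0]
step_state.global_step = 0
trainer.load_checkpoint(saved)
```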

# Cerebras Model Zoo Callbacks API

|                                                                                                                                                                              |                                                                                                                                                                                                                     |
| ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| [`cerebras.modelzoo.trainer.callbacks`](generated/cerebras.modelzoo.trainer.callbacks.html#module-cerebras.modelzoo.trainer.callbacks "cerebras.modelzoo.trainer.callbacks") | This module contains the base Callback class as well as a number of core callbacks directly invoked by the Trainer as well as other optional callbacks that can be used to extend the functionality of the Trainer. |

# Cerebras Model Zoo Extensions API

|                                                                                                                                                                                  |                                                                     |
| -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------- |
| [`cerebras.modelzoo.trainer.extensions`](generated/cerebras.modelzoo.trainer.extensions.html#module-cerebras.modelzoo.trainer.extensions "cerebras.modelzoo.trainer.extensions") | This module contains integrations of external tools to the Trainer. |

# Cerebras Model Zoo Loggers API

|                                                                                                                                                                      |                                                                                       |
| -------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------- |
| [`cerebras.modelzoo.trainer.loggers`](generated/cerebras.modelzoo.trainer.loggers.html#module-cerebras.modelzoo.trainer.loggers "cerebras.modelzoo.trainer.loggers") | This module contains the base Logger class as well as a few useful Logger subclasses. |
