Customizing the Trainer with Callbacks
The Trainer class was designed to be easily extendable using Callback classes. The Trainer exposes a number of hooks that can be overridden using a Callback.
On this page, you will learn about the basic Callback mechanism. By the end, you should be able to write and use your own custom Callback.
Prerequisites
Please ensure that you have read through the Cerebras Model Zoo Trainer Overview beforehand. The rest of this page assumes that you already have at least a cursory understanding of what the Cerebras Model Zoo Trainer is and how to use its Python API.
Callbacks
The callback mechanism is the backbone of the Trainer’s implementation. A lot of the heavy lifting in the Trainer is actually done by various Core Callbacks.
In general, the Callback mechanism exposes a number of useful hooks that allow you to inject custom behaviour into the Trainer. These hooks include (but are not limited to):
-
setup
-
on_{fit,train,validate}_{start,end}
-
on_{train,validate}_batch_{start,end}
-
on_{after,before}_{forward,backward}
-
on_{after,before}_optimizer_{step,zero_grad}
-
on_{after,before}_scheduler_step
-
on_{save,load}_checkpoint
-
on_after_save_checkpoint
-
on_before_load_checkpoint
The following pseudocode describes the structure of the fit call and where the various hooks get called.
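As a rough illustration only, here is a runnable Python model of a fit call. HookRecorder, call_hook, and fit are hypothetical names that are not part of the Model Zoo. The hook names come from the list above, but the exact order and arguments in the real Trainer may differ, and the checkpoint hooks are left out for brevity.

```python
# Hypothetical sketch of where hooks could fire during fit().
# Ordering and arguments are illustrative assumptions, not the Model Zoo
# Trainer's implementation. Checkpoint hooks are omitted for brevity.


class HookRecorder:
    """Toy callback that records the name of every hook it receives."""

    def __init__(self):
        self.calls = []

    def __getattr__(self, name):
        # Only invoked for attributes not found normally, i.e. hook names.
        if name == "setup" or name.startswith("on_"):
            return lambda *args: self.calls.append(name)
        raise AttributeError(name)


def call_hook(callbacks, hook, *args):
    """Invoke `hook` on every callback that defines it, in list order."""
    for callback in callbacks:
        method = getattr(callback, hook, None)
        if method is not None:
            method(*args)


def fit(callbacks, train_batches, val_batches):
    call_hook(callbacks, "setup")
    call_hook(callbacks, "on_fit_start")

    call_hook(callbacks, "on_train_start")
    for batch in train_batches:
        call_hook(callbacks, "on_train_batch_start", batch)
        call_hook(callbacks, "on_before_forward", batch)
        # outputs = model(batch)
        call_hook(callbacks, "on_after_forward", batch)
        call_hook(callbacks, "on_before_backward", batch)
        # outputs["loss"].backward()
        call_hook(callbacks, "on_after_backward", batch)
        call_hook(callbacks, "on_before_optimizer_step", batch)
        # optimizer.step()
        call_hook(callbacks, "on_after_optimizer_step", batch)
        call_hook(callbacks, "on_before_scheduler_step", batch)
        # scheduler.step()
        call_hook(callbacks, "on_after_scheduler_step", batch)
        call_hook(callbacks, "on_before_optimizer_zero_grad", batch)
        # optimizer.zero_grad()
        call_hook(callbacks, "on_after_optimizer_zero_grad", batch)
        call_hook(callbacks, "on_train_batch_end", batch)
    call_hook(callbacks, "on_train_end")

    call_hook(callbacks, "on_validate_start")
    for batch in val_batches:
        call_hook(callbacks, "on_validate_batch_start", batch)
        # outputs = model(batch)  (no backward pass during validation)
        call_hook(callbacks, "on_validate_batch_end", batch)
    call_hook(callbacks, "on_validate_end")

    call_hook(callbacks, "on_fit_end")


recorder = HookRecorder()
fit([recorder], train_batches=[0, 1], val_batches=[0])
print(recorder.calls[:4])  # ['setup', 'on_fit_start', 'on_train_start', 'on_train_batch_start']
```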
For a comprehensive list of all supported hooks (as well as the arguments they accept), see the API docs for the Callback class.
Pre-packaged Callbacks
There are many callbacks that come pre-packaged in the Model Zoo. See Add-on Callbacks for a complete list of all the callbacks available out of the box in the Model Zoo. You can use any number of them to enhance the Trainer for your run by passing them in the callbacks argument of the Trainer’s constructor.
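The real add-on callbacks need a full Trainer setup, so this sketch only shows the general pattern of combining several callbacks in one run. ToyTrainer, FiniteLossCheck, LossLogger, and the single on_after_forward(trainer, loss) hook are hypothetical stand-ins, not Model Zoo classes or signatures. The point is that the Trainer takes a list, and every callback in it receives each hook.

```python
import math


class Callback:
    """Hypothetical stand-in for the Model Zoo Callback base class."""

    def on_after_forward(self, trainer, loss):
        pass  # default hook does nothing


class FiniteLossCheck(Callback):
    """Hypothetical: stop the run if the loss is NaN or infinite."""

    def on_after_forward(self, trainer, loss):
        if not math.isfinite(loss):
            raise ValueError(f"non-finite loss: {loss}")


class LossLogger(Callback):
    """Hypothetical: remember every loss value seen."""

    def __init__(self):
        self.history = []

    def on_after_forward(self, trainer, loss):
        self.history.append(loss)


class ToyTrainer:
    """Hypothetical toy Trainer that 'trains' on precomputed loss values."""

    def __init__(self, callbacks=()):
        self.callbacks = list(callbacks)

    def fit(self, losses):
        for loss in losses:
            # Every callback in the list gets the same hook, in list order.
            for callback in self.callbacks:
                callback.on_after_forward(self, loss)


logger = LossLogger()
trainer = ToyTrainer(callbacks=[FiniteLossCheck(), logger])
trainer.fit([2.0, 1.5, 1.2])
print(logger.history)  # [2.0, 1.5, 1.2]
```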
Global Callbacks
Any callback can be registered globally so that all Trainer instances know about it and will invoke that callback’s hooks.
There are two ways to globally register a callback. The first way is to treat the callback as a context manager. For example, while inside a CheckLoss callback’s context, all Trainer fit calls made within that context will check the loss values that come out of the model.
The other way to register a callback is to call cerebras.modelzoo.trainer.callbacks.register_global_callback. For example, once a CheckLoss callback has been registered this way, all Trainer fit calls will check the loss values that come out of the model until the callback is removed.
register_global_callback returns a removable handle object that can be used to unregister the added callback by calling handle.remove().
Callback Ordering
The Trainer is composed of many different callbacks that all serve to enhance its functionality.
All of these callbacks share common hooks, and these hooks must be called in a specific order. Callbacks are invoked in the following order:
-
Core Callbacks: The callbacks that implement the most fundamental behaviour of the Trainer are called first.
-
User-defined callbacks: The callbacks that are passed into the callbacks argument of the Trainer’s constructor are called next.
-
Global callbacks: Finally, the callbacks that are registered globally are called.
For example, consider the on_fit_start hook in a run that involves three callbacks: the core TrainingLoop callback, a ComputeNorm callback passed into the Trainer’s constructor, and a globally registered CheckLoss callback. Their on_fit_start hooks are invoked in the following order:
-
TrainingLoop.on_fit_start: TrainingLoop is a core callback.
-
ComputeNorm.on_fit_start: ComputeNorm was passed into the Trainer’s constructor.
-
CheckLoss.on_fit_start: CheckLoss is a globally registered callback.
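To make the ordering concrete, here is a toy dispatcher that builds its callback list as core, then constructor-supplied, then global. ToyTrainingLoop, ToyComputeNorm, and ToyCheckLoss are hypothetical placeholders named after the callbacks above. They only log their own hook call, and ToyTrainer is not the Model Zoo Trainer.

```python
# Hypothetical global registry and shared log for this illustration.
_GLOBAL_CALLBACKS = []
call_log = []


class Callback:
    def on_fit_start(self, trainer):
        call_log.append(f"{type(self).__name__}.on_fit_start")


class ToyTrainingLoop(Callback):  # stands in for a core callback
    pass


class ToyComputeNorm(Callback):  # stands in for a user-supplied callback
    pass


class ToyCheckLoss(Callback):  # stands in for a globally registered callback
    pass


class ToyTrainer:
    def __init__(self, callbacks=()):
        self.core_callbacks = [ToyTrainingLoop()]
        self.callbacks = list(callbacks)

    def all_callbacks(self):
        # Core first, then constructor-supplied, then global.
        return self.core_callbacks + self.callbacks + list(_GLOBAL_CALLBACKS)

    def fit(self):
        for callback in self.all_callbacks():
            callback.on_fit_start(self)


_GLOBAL_CALLBACKS.append(ToyCheckLoss())
ToyTrainer(callbacks=[ToyComputeNorm()]).fit()
print(call_log)
# ['ToyTrainingLoop.on_fit_start', 'ToyComputeNorm.on_fit_start', 'ToyCheckLoss.on_fit_start']
```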
Writing a Custom Callback
To write your own custom callback class, all you need to do is inherit from the base Callback class and override the hooks that you need.
For example, a simple callback could scale the loss value by a constant factor before loss.backward() is called. It only needs to override the hook that runs right before the backward pass.
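Here is a minimal sketch of such a callback. It assumes a hook named on_before_backward that receives a mutable outputs dictionary holding the loss under "loss". The Callback base class is a hypothetical no-op stand-in, and the real hook signature in the Model Zoo may differ, so check the Callback API docs. A plain float stands in for the loss tensor.

```python
class Callback:
    """Hypothetical stand-in for the Model Zoo Callback base class (hooks are no-ops)."""

    def on_before_backward(self, trainer, model, outputs):
        pass


class ScaleLoss(Callback):
    """Multiply the loss by a constant factor just before the backward pass."""

    def __init__(self, scale):
        self.scale = scale

    def on_before_backward(self, trainer, model, outputs):
        # Assumption: outputs is a mutable dict holding the loss under "loss".
        outputs["loss"] = outputs["loss"] * self.scale


# Calling the hook by hand to show its effect (trainer and model unused here).
outputs = {"loss": 2.0}
ScaleLoss(scale=0.5).on_before_backward(trainer=None, model=None, outputs=outputs)
print(outputs["loss"])  # 1.0
```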
Once defined, such a callback can be used inside the Trainer like any other callback, by passing an instance of it in the callbacks argument of the Trainer’s constructor.
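Here is a self-contained sketch of that usage with a hypothetical ToyTrainer. The trainer computes a loss per batch with a made-up mean_model, lets callbacks adjust it through on_before_backward, and records the value it would backpropagate. Every name here is illustrative, and the real Trainer constructor takes many more arguments.

```python
class Callback:
    """Hypothetical stand-in for the Model Zoo Callback base class."""

    def on_before_backward(self, trainer, model, outputs):
        pass


class ScaleLoss(Callback):
    def __init__(self, scale):
        self.scale = scale

    def on_before_backward(self, trainer, model, outputs):
        outputs["loss"] = outputs["loss"] * self.scale


class ToyTrainer:
    """Hypothetical trainer: runs hooks, then records the loss for 'backward'."""

    def __init__(self, model, callbacks=()):
        self.model = model
        self.callbacks = list(callbacks)
        self.backward_losses = []

    def fit(self, batches):
        for batch in batches:
            outputs = {"loss": self.model(batch)}
            for callback in self.callbacks:
                callback.on_before_backward(self, self.model, outputs)
            # A real trainer would call outputs["loss"].backward() here.
            self.backward_losses.append(outputs["loss"])


def mean_model(batch):
    """Hypothetical 'model' whose loss is just the batch mean."""
    return sum(batch) / len(batch)


trainer = ToyTrainer(mean_model, callbacks=[ScaleLoss(0.5)])
trainer.fit([[1.0, 3.0], [4.0, 8.0]])
print(trainer.backward_losses)  # [1.0, 3.0]
```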
For the callback class to exist in the Python global namespace, the Python interpreter must have seen its definition at some point. Implementing your custom callback in run.py, or in the same file as the model class, are two ways to ensure that the callback is seen by the interpreter and loaded into the global namespace.
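The underlying Python behaviour is that a class name only exists in a namespace after its class statement has executed. This sketch assumes, for illustration only, that a callback is looked up by class name. The resolve_callback helper is hypothetical and is not how the Model Zoo necessarily resolves callbacks.

```python
def resolve_callback(name, namespace):
    """Hypothetical: look up a callback class by name in a namespace dict."""
    try:
        return namespace[name]
    except KeyError:
        raise LookupError(
            f"{name!r} is not defined; make sure the module that defines it "
            "has been imported (for example, from run.py)"
        ) from None


# Before the class statement runs, the name does not exist yet.
try:
    resolve_callback("ScaleLoss", globals())
except LookupError as err:
    print(err)


class ScaleLoss:
    def __init__(self, scale):
        self.scale = scale


# After the definition has executed, the lookup succeeds.
cls = resolve_callback("ScaleLoss", globals())
print(cls(scale=0.5).scale)  # 0.5
```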
Conclusion
By this point, you should have a cursory understanding of how callbacks can be used to enhance the Trainer. There are many useful callbacks that come pre-packaged in the Model Zoo. If you need some functionality that is not covered, you should be comfortable writing your own callback to implement it.