The Trainer class was designed to be easily extendable using Callback classes. The Trainer exposes a number of hooks which can be overridden using a Callback.
On this page, you will learn about the basic Callback mechanism. By the end, you should be able to write and use your own custom Callback.
Trainer’s implementation. A lot of the heavy lifting in the Trainer is actually done by various Core Callbacks.
In general, the Callback mechanism exposes a number of useful hooks that allow you to inject certain behaviour into the Trainer. These hooks include (but are not limited to):
setup
on_{fit,train,validate}_{start,end}
on_{train,validate}_batch_{start,end}
on_{after,before}_{forward,backward}
on_{after,before}_optimizer_{step,zero_grad}
on_{after,before}_scheduler_step
on_{save,load}_checkpoint
on_after_save_checkpoint
on_before_load_checkpoint
fit call and where the various hooks get called.
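To make the idea of hooks firing at fixed points in a fit call concrete, here is a minimal sketch. It does not use the real library: `ToyTrainer`, `Recorder`, the stand-in `Callback` base class, and the hook signatures are all hypothetical, and only a handful of the hooks listed above are modelled.

```python
class Callback:
    """Stand-in base class (assumption): every hook is a no-op by default."""

    def on_fit_start(self, trainer): pass
    def on_train_batch_start(self, trainer, batch_idx): pass
    def on_train_batch_end(self, trainer, batch_idx): pass
    def on_fit_end(self, trainer): pass


class ToyTrainer:
    """Hypothetical trainer that only demonstrates hook dispatch."""

    def __init__(self, callbacks):
        self.callbacks = list(callbacks)

    def call(self, hook_name, *args):
        # Invoke the named hook on every callback, in registration order.
        for cb in self.callbacks:
            getattr(cb, hook_name)(self, *args)

    def fit(self, num_batches):
        self.call("on_fit_start")
        for i in range(num_batches):
            self.call("on_train_batch_start", i)
            # Forward, backward, and optimizer step would run here.
            self.call("on_train_batch_end", i)
        self.call("on_fit_end")


class Recorder(Callback):
    """Overrides a few hooks and records when they fire."""

    def __init__(self):
        self.events = []

    def on_fit_start(self, trainer):
        self.events.append("fit_start")

    def on_train_batch_end(self, trainer, batch_idx):
        self.events.append(f"batch_end:{batch_idx}")

    def on_fit_end(self, trainer):
        self.events.append("fit_end")


recorder = Recorder()
ToyTrainer([recorder]).fit(num_batches=2)
print(recorder.events)
```

Hooks that a callback does not override fall through to the no-op defaults, which is why a callback only needs to implement the hooks it cares about.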
Callback class.
Trainer instances know about it and will invoke that callback’s hooks.
There are two ways to globally register a callback. The first way is to treat the callback as a context manager. For example,
CheckLoss’s context, all trainer fit calls inside the context will check the loss values that come out of the model.
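The context-manager form of global registration can be sketched as follows. This is an illustration under stated assumptions, not the library's implementation: `GLOBAL_CALLBACKS`, `ToyTrainer`, the stand-in `Callback` base class, and the `on_after_forward(trainer, loss)` signature are hypothetical, and this `CheckLoss` is a simplified stand-in that only rejects non-finite losses.

```python
import math

GLOBAL_CALLBACKS = []  # hypothetical registry shared by all trainers


class Callback:
    """Stand-in base class; not the real cerebras.modelzoo Callback."""

    def on_after_forward(self, trainer, loss):
        pass

    def __enter__(self):
        GLOBAL_CALLBACKS.append(self)  # register on entering the context
        return self

    def __exit__(self, exc_type, exc, tb):
        GLOBAL_CALLBACKS.remove(self)  # unregister on leaving the context
        return False  # never swallow exceptions


class CheckLoss(Callback):
    """Simplified stand-in: records finite losses, rejects the rest."""

    def __init__(self):
        self.checked = []

    def on_after_forward(self, trainer, loss):
        if not math.isfinite(loss):
            raise ValueError(f"non-finite loss: {loss}")
        self.checked.append(loss)


class ToyTrainer:
    def fit(self, losses):
        for loss in losses:
            # Every trainer consults the same global registry.
            for cb in list(GLOBAL_CALLBACKS):
                cb.on_after_forward(self, loss)


check_loss = CheckLoss()
with check_loss:
    ToyTrainer().fit([0.9, 0.7])  # checked
    ToyTrainer().fit([0.5])       # checked, even on a different trainer
ToyTrainer().fit([float("nan")])  # outside the context: not checked
```

Because the registry lives outside any single trainer, every trainer created inside the `with` block sees the callback, and none created afterwards does.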
The other way to register a callback is to call cerebras.modelzoo.trainer.callbacks.register_global_callback.
For example,
fit calls inside the context will check the loss values that come out of the model.
cerebras.modelzoo.trainer.callbacks.register_global_callback returns a removable handle object that can be used to remove the added callback by calling handle.remove().
Trainer is composed of many different callbacks that all serve to enhance its functionality.
All of these callbacks share common hooks. These hooks must be called in a specific order. The order in which callbacks are invoked is as follows:
Core callbacks of the Trainer get called first.
Callbacks passed to the callbacks argument of the Trainer’s constructor are called next.
Globally registered callbacks are called last.
on_fit_start hook. Between the three callbacks that are highlighted in the above example, the order in which the callbacks’ on_fit_start hooks are invoked is as follows:
TrainingLoop.on_fit_start: as TrainingLoop is a core callback.
ComputeNorm.on_fit_start: as ComputeNorm was passed into the Trainer’s constructor.
CheckLoss.on_fit_start: as CheckLoss is a globally registered callback.
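The ordering above (core, then constructor-supplied, then globally registered) can be reproduced with a small stand-in. Everything here is hypothetical: `ToyTrainer`, `GLOBAL_CALLBACKS`, and the empty classes borrowing the names `TrainingLoop`, `ComputeNorm`, and `CheckLoss` only mimic the roles those callbacks play in the example.

```python
GLOBAL_CALLBACKS = []  # hypothetical global registry


class Callback:
    """Stand-in base class whose hook just logs that it ran."""

    def on_fit_start(self, trainer):
        trainer.log.append(f"{type(self).__name__}.on_fit_start")


class TrainingLoop(Callback): pass  # plays the role of a core callback
class ComputeNorm(Callback): pass   # passed to the constructor
class CheckLoss(Callback): pass     # registered globally


class ToyTrainer:
    def __init__(self, callbacks=()):
        self.core_callbacks = [TrainingLoop()]
        self.user_callbacks = list(callbacks)
        self.log = []

    def all_callbacks(self):
        # Concatenation order determines invocation order.
        return [*self.core_callbacks, *self.user_callbacks, *GLOBAL_CALLBACKS]

    def fit(self):
        for cb in self.all_callbacks():
            cb.on_fit_start(self)


GLOBAL_CALLBACKS.append(CheckLoss())
trainer = ToyTrainer(callbacks=[ComputeNorm()])
trainer.fit()
print(trainer.log)
```

Note that the global callback is invoked last even though it was registered before the trainer was constructed; in this sketch the order depends on the category of the callback, not on when it was created.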
Callback class and override the hooks that you need.
For example, let’s implement a simple callback that scales the loss value by some constant value before we call loss.backward().
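A minimal sketch of such a loss-scaling callback is shown below. To keep it self-contained it uses a stand-in `Callback` base class rather than the real one, and it assumes (without confirmation from this page) that the `on_before_backward` hook receives the trainer, the model, and an `outputs` dictionary holding the loss under the key `"loss"`. A plain float stands in for the loss tensor.

```python
class Callback:
    """Stand-in base class (assumption); hooks default to no-ops."""

    def on_before_backward(self, trainer, model, outputs):
        pass


class ScaleLoss(Callback):
    """Multiplies the loss by a constant just before the backward pass."""

    def __init__(self, scale):
        self.scale = scale

    def on_before_backward(self, trainer, model, outputs):
        # Assumed layout: the loss lives in outputs["loss"].
        outputs["loss"] = outputs["loss"] * self.scale


# Simulate the point in the training step right before loss.backward().
outputs = {"loss": 2.0}
ScaleLoss(scale=0.5).on_before_backward(trainer=None, model=None, outputs=outputs)
scaled_loss = outputs["loss"]
```

Only the one hook that matters is overridden; every other hook keeps its default behaviour from the base class.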
Trainer as follows:
run.py or in the same file as the model class are two ways to ensure that the callback is seen by the Python interpreter and loaded into the Python global namespace.