cerebras.pytorch.amp.GradScaler. For example:
By default, the half-precision dtype is float16.
If you want to use cbfloat16 or bfloat16 instead of float16, call cerebras.pytorch.amp.set_half_dtype.
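The choice of half dtype matters because float16 and bfloat16 trade range against precision. float16 keeps 10 mantissa bits but overflows above 65504. bfloat16 keeps float32's 8-bit exponent, and therefore its range, but has only 7 mantissa bits. The minimal pure-Python sketch below illustrates that tradeoff. It does not use the Cerebras API. The helper names round_to_float16 and round_to_bfloat16 are hypothetical. bfloat16 is emulated by truncating a float32, whereas hardware typically rounds to nearest even. cbfloat16, a Cerebras-specific format, is not modeled.

```python
import struct


def round_to_float16(x: float) -> float:
    """Round x to IEEE 754 half precision (float16) using struct's 'e' format.

    struct raises OverflowError if x is beyond float16's range (max 65504).
    """
    return struct.unpack("<e", struct.pack("<e", x))[0]


def round_to_bfloat16(x: float) -> float:
    """Emulate bfloat16 by keeping only the upper 16 bits of the float32 encoding.

    This truncates (rounds toward zero); hardware typically rounds to nearest even.
    """
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]


if __name__ == "__main__":
    # Range: 70000 overflows float16 but is representable (coarsely) in bfloat16.
    try:
        round_to_float16(70000.0)
    except OverflowError:
        print("float16: overflow")
    print("bfloat16:", round_to_bfloat16(70000.0))  # 69632.0

    # Precision: 1 + 2**-10 survives float16 but is truncated to 1.0 in bfloat16.
    print(round_to_float16(1.0 + 2**-10), round_to_bfloat16(1.0 + 2**-10))
```

A model whose activations or gradients exceed float16's range is a candidate for bfloat16. A model that needs finer resolution near 1.0 may do better with float16 plus loss scaling.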
Use cerebras.pytorch.amp.optimizer_step to take care of the details of gradient scaling.
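To show what "the details of gradient scaling" typically involve, here is a framework-free sketch of dynamic loss scaling, the general technique a gradient scaler manages. It does not reproduce how Cerebras implements GradScaler or optimizer_step. The class DynamicLossScaler, its default values, and the plain SGD update are illustrative assumptions. The loss is multiplied by a scale factor so small half-precision gradients do not underflow. Gradients are divided by the same factor before the update. If an overflow (inf/nan) appears, the step is skipped and the scale is reduced. After a run of finite steps, the scale is increased again.

```python
import math
from dataclasses import dataclass, field
from typing import List


@dataclass
class DynamicLossScaler:
    """Hypothetical sketch of dynamic loss scaling (not the Cerebras implementation)."""

    scale: float = 2.0**15          # illustrative starting loss scale
    growth_factor: float = 2.0      # multiply scale after a run of good steps
    backoff_factor: float = 0.5     # multiply scale after an overflow
    growth_interval: int = 2000     # consecutive good steps before growing
    _good_steps: int = field(default=0, init=False)

    def scale_loss(self, loss: float) -> float:
        # Multiplying the loss scales every gradient by the same factor,
        # keeping small gradients from underflowing in half precision.
        return loss * self.scale

    def step(self, params: List[float], scaled_grads: List[float], lr: float) -> bool:
        """Apply one SGD update from scaled gradients; return False if skipped."""
        if not all(math.isfinite(g) for g in scaled_grads):
            # Overflow (inf/nan): skip the update and reduce the scale.
            self.scale *= self.backoff_factor
            self._good_steps = 0
            return False
        for i, g in enumerate(scaled_grads):
            # Unscale before applying the update so the step size is unaffected.
            params[i] -= lr * (g / self.scale)
        self._good_steps += 1
        if self._good_steps == self.growth_interval:
            # A run of finite steps suggests headroom: try a larger scale.
            self.scale *= self.growth_factor
            self._good_steps = 0
        return True


if __name__ == "__main__":
    scaler = DynamicLossScaler(scale=8.0, growth_interval=2)
    params = [1.0]
    print(scaler.step(params, [float("inf")], lr=0.1), scaler.scale)  # False 4.0
    print(scaler.step(params, [4.0], lr=0.1), params)  # True [0.9]
```

Skipping the update on overflow, rather than clipping, keeps a single bad step from corrupting the weights. The growth interval controls how aggressively the scaler probes for a larger scale.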