Gradient Scaling
Gradient scaling can improve convergence when training models with float16 gradients by minimizing gradient underflow. Please see the PyTorch docs for a more detailed explanation.
To facilitate gradient scaling, we introduce a Cerebras-compliant implementation of the AMP GradScaler class found in core PyTorch, available at cerebras.pytorch.amp.GradScaler. It is designed to match the API of the CUDA AMP GradScaler class as closely as possible, and its usage is identical.
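As a sketch of a typical training step (the model, optimizer, and batch names here are illustrative placeholders, and the `loss_scale="dynamic"` argument assumes the dynamic loss-scaling mode described in the Cerebras documentation):

```python
import cerebras.pytorch as cstorch

# Construct the Cerebras GradScaler, analogous to torch.cuda.amp.GradScaler
grad_scaler = cstorch.amp.GradScaler(loss_scale="dynamic")

# Inside the training step (model, optimizer, and batch are assumed
# to be defined elsewhere):
loss = model(batch)

# Scale the loss before backward() so that small float16 gradients
# do not underflow to zero
grad_scaler.scale(loss).backward()

# step() unscales the gradients and applies the optimizer update,
# skipping the update if non-finite gradients are detected
grad_scaler.step(optimizer)

# update() adjusts the loss scale for the next iteration
grad_scaler.update()
```

The scale/backward/step/update sequence mirrors the standard CUDA AMP pattern, so existing PyTorch AMP training loops should carry over with little change.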
Using Automatic Mixed Precision
By default, automatic mixed precision uses float16. If you want to use cbfloat16 or bfloat16 instead of float16, call cerebras.pytorch.amp.set_half_dtype.
Using a Helper Function for Gradient Scaling
We introduce an optional helper function, cerebras.pytorch.amp.optimizer_step, to take care of the details of gradient scaling. It is useful for quickly constructing typical gradient-scaling training steps without needing to write out the details or worry about whether the grad scaler is being used correctly.
This is entirely optional and only covers the basic gradient scaler use case. For more complicated use cases, the grad scaler object must be used explicitly.
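A sketch of the basic use case (the model, optimizer, and batch names are illustrative placeholders, and the max_gradient_norm keyword is an assumption based on the optional gradient-clipping support described for this helper):

```python
import cerebras.pytorch as cstorch

grad_scaler = cstorch.amp.GradScaler(loss_scale="dynamic")

# Inside the training step (model, optimizer, and batch are assumed
# to be defined elsewhere):
loss = model(batch)

# Performs the scale/backward/step/update sequence from the explicit
# GradScaler example in a single call
cstorch.amp.optimizer_step(
    loss,
    optimizer,
    grad_scaler,
    max_gradient_norm=1.0,  # illustrative clipping value
)
```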