cerebras.pytorch.amp
Automatic mixed precision
The following classes and functions are designed to facilitate automatic mixed precision on the Cerebras Wafer-Scale Cluster.
GradScaler
class cerebras.pytorch.amp.GradScaler(loss_scale=None, init_scale=None, steps_per_increase=None, min_loss_scale=None, max_loss_scale=None, overflow_tolerance=0.0, max_gradient_norm=None)
Facilitates mixed precision training with dynamic loss scaling (DLS) and DLS with global gradient clipping (GCC).
For more details, please see the docs for amp.initialize.
Parameters:
- loss_scale (Union[str, float]) – If loss_scale == "dynamic", then configure dynamic loss scaling. Otherwise, it is the loss scale value used in static loss scaling.
- init_scale (float) – The initial loss scale value if loss_scale == "dynamic".
- steps_per_increase (int) – The number of steps after which to increase the loss scale.
- min_loss_scale (float) – The minimum loss scale value that can be chosen by dynamic loss scaling.
- max_loss_scale (float) – The maximum loss scale value that can be chosen by dynamic loss scaling.
- overflow_tolerance (float) – The maximum fraction of steps with infinite or NaN gradient values that we allow. The loss scale is reduced if this tolerance is exceeded.
- max_gradient_norm (float) – The maximum gradient norm to use for global gradient clipping. Only applies in the DLS + GCC case. If GCC is not enabled, this parameter has no effect.
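For illustration, a GradScaler configured for dynamic loss scaling might be constructed as sketched below; the numeric values are arbitrary examples rather than recommended defaults, and cstorch is the conventional import alias for cerebras.pytorch.

```python
import cerebras.pytorch as cstorch

# Dynamic loss scaling with global gradient clipping (values are illustrative)
grad_scaler = cstorch.amp.GradScaler(
    loss_scale="dynamic",
    init_scale=2.0**16,
    steps_per_increase=2000,
    min_loss_scale=1.0,
    max_loss_scale=2.0**32,
    overflow_tolerance=0.05,
    max_gradient_norm=1.0,
)
```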
state_dict(destination=None)
Returns a dictionary containing the state to be saved to a checkpoint.
load_state_dict(state_dict)
Loads the state dictionary into the current instance.
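A minimal sketch of checkpointing the scaler state with these two methods; the checkpoint dictionary and its key are assumptions.

```python
# Save the GradScaler state alongside other checkpoint state
checkpoint = {"grad_scaler": grad_scaler.state_dict()}

# ... later, restore it into a freshly constructed GradScaler
grad_scaler.load_state_dict(checkpoint["grad_scaler"])
```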
scale(loss)
Scales the loss in preparation for the backward pass.
get_scale()
Returns the loss scale.
unscale_(optimizer)
Unscales the optimizer's parameter gradients in place.
step_if_finite(optimizer, *args, **kwargs)
Directly performs the optimizer.step(*args, **kwargs) call, but only if this GradScaler detected finite gradients.
Parameters:
- optimizer (torch.optim.Optimizer) – Optimizer that applies the gradients.
- args – Any arguments passed to the optimizer.step() call.
- kwargs – Any keyword arguments passed to the optimizer.step() call.
Returns: The result of optimizer.step()
clip_gradients_and_return_isfinite(optimizers)
Clips the optimizers' parameter gradients and returns whether or not the norm is finite.
step(optimizer, *args, **kwargs)
Step carries out the following two operations:
1. Internally invokes unscale_(optimizer) (unless unscale_ was explicitly called for optimizer earlier in the iteration). As part of the unscale_, gradients are checked for infs/NaNs.
2. Invokes optimizer.step() using the unscaled gradients, ensuring that previous optimizer state or params carry over if NaNs are encountered in the gradients.
*args and **kwargs are forwarded to optimizer.step(). Returns the return value of optimizer.step(*args, **kwargs).
Parameters:
- optimizer (cerebras.pytorch.optim.Optimizer) – Optimizer that applies the gradients.
- args – Any arguments.
- kwargs – Any keyword arguments.
update_scale(optimizers)
Update the scales of the optimizers.
update(new_scale=None)
Updates the gradient scaler after all optimizers have been stepped.
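Taken together, these methods support the usual PyTorch-style loss scaling loop. A minimal sketch, assuming a model, dataloader, loss_fn, and optimizer defined elsewhere:

```python
for inputs, targets in dataloader:
    optimizer.zero_grad()

    loss = loss_fn(model(inputs), targets)
    grad_scaler.scale(loss).backward()  # backward pass on the scaled loss

    grad_scaler.step(optimizer)  # unscales gradients, steps only if they are finite
    grad_scaler.update()         # adjusts the loss scale for the next iteration
```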
set_half_dtype
cerebras.pytorch.amp.set_half_dtype(value)
Sets the underlying 16-bit floating point dtype to use.
Parameters:
value (Union[Literal["float16", "bfloat16", "cbfloat16"], torch.dtype]) – Either a 16-bit floating point torch dtype or one of the strings "float16", "bfloat16", or "cbfloat16".
Returns: The proxy torch dtype to use for the model. For dtypes that have a torch representation, this returns the same value that was passed in. Otherwise, it returns a proxy dtype to use in the model. On CSX, these proxy dtypes are automatically and transparently converted to the real dtype during compilation.
Return type: torch.dtype
By default, automatic mixed precision uses float16. If you want to use cbfloat16 or bfloat16 instead of float16, call this function.
Example usage:
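A minimal sketch, assuming the conventional cstorch import alias for cerebras.pytorch:

```python
import cerebras.pytorch as cstorch

# Use cbfloat16 instead of the default float16 for automatic mixed precision
cstorch.amp.set_half_dtype("cbfloat16")
```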
optimizer_step
cerebras.pytorch.amp.optimizer_step(loss, optimizer, grad_scaler, max_gradient_norm=None, max_gradient_value=None)
Performs loss scaling, gradient scaling, and the optimizer step.
Parameters:
- loss (torch.Tensor) – The loss value to scale. loss.backward should be called before this function.
- optimizer (cerebras.pytorch.optim.optimizer.Optimizer) – The optimizer to step.
- grad_scaler (cerebras.pytorch.amp.grad_scaler.GradScaler) – The gradient scaler to use to scale the parameter gradients.
- max_gradient_norm (Optional[float]) – The max gradient norm to use for gradient clipping.
- max_gradient_value (Optional[float]) – The max gradient value to use for gradient clipping.
Example usage:
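A minimal sketch, assuming a model, loss_fn, optimizer, and grad_scaler set up elsewhere:

```python
import cerebras.pytorch as cstorch

loss = loss_fn(model(inputs), targets)
loss.backward()  # per the loss parameter description, backward is called first

cstorch.amp.optimizer_step(
    loss,
    optimizer,
    grad_scaler,
    max_gradient_norm=1.0,  # optional: enables global gradient clipping
)
```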