Automatic mixed precision

The following classes and functions facilitate automatic mixed precision on the Cerebras Wafer Scale Cluster.

GradScaler#

class cerebras.pytorch.amp.GradScaler(loss_scale=None, init_scale=None, steps_per_increase=None, min_loss_scale=None, max_loss_scale=None, overflow_tolerance=0.0, max_gradient_norm=None)

[source]#

Facilitates mixed precision training with dynamic loss scaling (DLS) and DLS + global gradient clipping (GCC).

For more details, please see the docs for amp.initialize.

Parameters:

  • loss_scale (Union[str, float]) – If loss_scale == “dynamic”, then configure dynamic loss scaling. Otherwise, it is the loss scale value used in static loss scaling.

  • init_scale (float) – The initial loss scale value if loss_scale == “dynamic”

  • steps_per_increase (int) – The number of steps after which to increase the loss scale when using dynamic loss scaling

  • min_loss_scale (float) – The minimum loss scale value that can be chosen by dynamic loss scaling

  • max_loss_scale (float) – The maximum loss scale value that can be chosen by dynamic loss scaling

  • overflow_tolerance (float) – The maximum fraction of steps involving infinite or undefined values in the gradient that we allow. The loss scale is reduced if this tolerance is exceeded

  • max_gradient_norm (float) – The maximum gradient norm to use for global gradient clipping. Only applies in the DLS + GCC case; if GCC is not enabled, this parameter has no effect

grad_scaler = cstorch.amp.GradScaler(loss_scale="dynamic")

loss: torch.Tensor = ...

optimizer.zero_grad()
# Scale the loss before calling the backward pass
grad_scaler.scale(loss).backward()

# Unscales the gradients of optimizer's assigned params in-place
# to facilitate things like gradient clipping
grad_scaler.unscale_(optimizer)

# Global gradient clipping
torch.nn.utils.clip_grad_norm_(
    model.parameters(),
    1.0,  # max gradient norm
)

# Step the optimizer using the grad scaler
grad_scaler.step(optimizer)

# update the grad scaler once all optimizers have been stepped
grad_scaler.update()
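
The remaining constructor arguments tune dynamic loss scaling. The sketch below is illustrative only; the specific values are assumptions, not recommended defaults:

import cerebras.pytorch as cstorch

grad_scaler = cstorch.amp.GradScaler(
    loss_scale="dynamic",
    init_scale=2.0**16,       # initial loss scale
    steps_per_increase=2000,  # steps after which the scale may be increased
    min_loss_scale=1.0,       # lower bound for the dynamically chosen scale
    max_loss_scale=2.0**30,   # upper bound for the dynamically chosen scale
    overflow_tolerance=0.05,  # fraction of overflowing steps tolerated before the scale is reduced
)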

state_dict(destination=None)

[source]#

Returns a dictionary containing the state to be saved to a checkpoint

load_state_dict(state_dict)

[source]#

Loads the state dictionary into the current params
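
A minimal sketch of checkpointing the scaler state alongside other training state; the checkpoint layout and the use of torch.save/torch.load are illustrative assumptions (model and optimizer come from the surrounding training setup):

import torch
import cerebras.pytorch as cstorch

grad_scaler = cstorch.amp.GradScaler(loss_scale="dynamic")

# Save the scaler state together with the rest of the training state
torch.save(
    {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "grad_scaler": grad_scaler.state_dict(),
    },
    "checkpoint.pt",
)

# ... later, restore only the scaler portion
state = torch.load("checkpoint.pt")
grad_scaler.load_state_dict(state["grad_scaler"])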

scale(loss)

[source]#

Scales the loss in preparation of the backwards pass

get_scale()

[source]#

Return the loss scale

unscale_(optimizer)

[source]#

Unscales the optimizer’s params gradients inplace

step_if_finite(optimizer, *args, **kwargs)

[source]#

Directly call optimizer.step(*args, **kwargs), but only if this GradScaler detected finite gradients.

Parameters:

  • optimizer (torch.optim.Optimizer) – Optimizer that applies the gradients.

  • args – Any arguments passed to the optimizer.step() call.

  • kwargs – Any keyword arguments passed to the optimizer.step() call.

Returns: The result of optimizer.step()
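
A minimal sketch of using step_if_finite in place of step, reusing grad_scaler, loss, and optimizer from the example above; whether unscale_ and update are still called around it follows the same pattern as step and is an assumption here:

grad_scaler.scale(loss).backward()

# Unscale so the finiteness check and the optimizer see the real gradients
grad_scaler.unscale_(optimizer)

# optimizer.step() is only invoked if the unscaled gradients are finite;
# the return value is whatever optimizer.step() returned
result = grad_scaler.step_if_finite(optimizer)

grad_scaler.update()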

clip_gradients_and_return_isfinite(optimizers)

[source]#

Clip the optimizers’ parameter gradients and return whether or not the norm is finite

step(optimizer, *args, **kwargs)

[source]#

Step carries out the following two operations:

  1. Internally invokes unscale_(optimizer) (unless unscale_ was explicitly called for optimizer earlier in the iteration). As part of the unscale_, gradients are checked for infs/NaNs.

  2. Invokes optimizer.step() using the unscaled gradients. Ensures that previous optimizer state or params carry over if we encounter NaNs in the gradients.

*args and **kwargs are forwarded to optimizer.step(). Returns the return value of optimizer.step(*args, **kwargs).

Parameters:

  • optimizer (cerebras.pytorch.optim.Optimizer) – Optimizer that applies the gradients.

  • args – Any arguments passed to the optimizer.step() call.

  • kwargs – Any keyword arguments passed to the optimizer.step() call.

update_scale(optimizers)

[source]#

Update the scales of the optimizers

update(new_scale=None)

[source]#

Update the gradient scaler after all optimizers have been stepped

set_half_dtype#

cerebras.pytorch.amp.set_half_dtype(value)

[source]#

Sets the underlying 16-bit floating point dtype to use.

Parameters:

value (Union[Literal['float16', 'bfloat16', 'cbfloat16'], torch.dtype]) – Either a 16-bit floating point torch dtype or one of the strings “float16”, “bfloat16”, or “cbfloat16”.

Returns: The proxy torch dtype to use for the model. For dtypes that have a torch representation, this returns the same as value passed in. Otherwise, it returns a proxy dtype to use in the model. On CSX, these proxy dtypes are automatically and transparently converted to the real dtype during compilation.

Return type: torch.dtype

By default, automatic mixed precision uses float16. If you want to use cbfloat16 or bfloat16 instead of float16, call this function.

Example usage:

cstorch.amp.set_half_dtype("cbfloat16")
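
The return value can also be captured when you need the dtype yourself, for example to cast inputs in a hypothetical input-pipeline helper; for “bfloat16” the returned proxy dtype is simply torch.bfloat16:

import torch
import cerebras.pytorch as cstorch

half_dtype = cstorch.amp.set_half_dtype("bfloat16")  # returns torch.bfloat16

def cast_inputs(features: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: cast floating point inputs to the configured half dtype
    return features.to(half_dtype)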

optimizer_step#

cerebras.pytorch.amp.optimizer_step(loss, optimizer, grad_scaler, max_gradient_norm=None, max_gradient_value=None)

[source]#

Performs loss scaling, gradient scaling and optimizer step

Parameters:

  • loss (torch.Tensor) – The loss value to scale. loss.backward should be called before this function

  • optimizer (cerebras.pytorch.optim.optimizer.Optimizer) – The optimizer to step

  • grad_scaler (cerebras.pytorch.amp.grad_scaler.GradScaler) – The gradient scaler to use to scale the parameter gradients

  • max_gradient_norm (Optional[float]) – The max gradient norm to use for gradient clipping

  • max_gradient_value (Optional[float]) – The max gradient value to use for gradient clipping

Example usage:

cstorch.amp.optimizer_step(
    loss,
    optimizer,
    grad_scaler,
    max_gradient_norm=1.0,
)