float32 and one of: cbfloat16, bfloat16, or float16.
Use torch.where for conditional logic. The torch.where function acts as a tensor-based conditional, equivalent to an if statement but represented within the graph. Avoid any Python flow control that could cause a different graph to be traced; only tensor operations are traced during execution.
So, in a code snippet such as the one sketched below, the value of input1.max() is not known at trace time but is required to compute the Python conditional. Hence, this will lead to a tracing error.
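A minimal sketch of the failing pattern, together with a torch.where rewrite (input2 and the surrounding code are illustrative additions, not the original example):

```python
import torch

input1 = torch.randn(8)
input2 = torch.randn(8)

# Fails under tracing: the Python `if` needs the concrete value of input1.max(),
# which is not known while the graph is being traced.
if input1.max() > input2.max():
    out = input1 + input2
else:
    out = input1 - input2

# Traces cleanly: the conditional is expressed as a tensor operation inside the graph.
out = torch.where(input1.max() > input2.max(), input1 + input2, input1 - input2)
```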
Other operations that force the value of a traced tensor to be materialized include moving the tensor to the host with tensor.to("cpu"), and calling tensor.item().
Some torch operations expect scalar values as arguments, but passing traced tensors to them can lead to unexpected behavior. Always explicitly convert traced tensors to scalars before using them in these contexts, and try restructuring the code to avoid traced tensors in scalar operations whenever possible, as explicit conversions can impact performance.
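As a brief illustration of the restructuring advice (logits is a hypothetical tensor name), keeping the computation in tensor form avoids converting a traced value to a Python scalar:

```python
import torch

logits = torch.randn(4, 10)

# Problematic under tracing: .item() forces the traced maximum into a Python float.
max_val = logits.max().item()
scaled_by_scalar = logits / max_val

# Restructured: the maximum stays a tensor, so the whole computation remains in the graph.
scaled = logits / logits.max()
```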
Torch operations such as torch.add (whose alpha argument expects a scalar, not a traced tensor) and torch.addcdiv (whose value argument expects a scalar) can offer performance benefits, since they leverage fused multiply-add or similar techniques. However, remember that they expect true scalar values for those arguments (e.g., value, alpha):
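A short sketch of the intended usage (the tensors and values here are illustrative, not from the original document):

```python
import torch

p = torch.randn(10)
grad = torch.randn(10)
exp_avg = torch.randn(10)
denom = torch.rand(10) + 1e-8
lr = 0.01  # a true Python scalar

# alpha and value must be Python scalars, not traced tensors.
p.add_(grad, alpha=-lr)
p.addcdiv_(exp_avg, denom, value=-lr)
```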
Passing a traced tensor (such as a traced learning-rate tensor lrs) where one of these scalar arguments is expected results in a SyncTensor called outside of MarkStep error, and a random value will be used instead. In addition, if this was the only “use” of the lrs tensor, walking the dependency graph of all traced values will show no computation depending on the value of lrs, and the entire set of tensor operations will not be lowered and compiled.
The value of state[param]["momentum"] computed on step 1 should be fed in as the initial value for state[param]["momentum"] on step 2, and so on. These values are “weights” and are kept resident in system memory of some kind (on wafer in pipeline mode, or in the MemoryX/weight hosts in weight streaming). In this case (even disregarding the Python flow control), state[param]["momentum"] has no identity before this first step. Executing the operations recorded by this trace would re-initialize momentum to zero on every single step, and the updated tensor result would have no place to be stored (it would be treated as just another model output, like loss or predictions).
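For contrast, this is the kind of lazy, in-trace initialization the paragraph above warns about, written as a small standalone sketch (the names and surrounding optimizer code are illustrative):

```python
import torch

param = torch.nn.Parameter(torch.randn(4))
param.grad = torch.randn(4)
state = {param: {}}
beta = 0.9

# Typical eager-PyTorch pattern: create the buffer lazily inside step().
# Under tracing, this re-creates momentum as zeros on every step, and the
# updated buffer has no persistent identity to be written back to.
if "momentum" not in state[param]:
    state[param]["momentum"] = torch.zeros_like(param)
state[param]["momentum"].mul_(beta).add_(param.grad)
```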
The preinitialize() function ensures that all stateful tensors within the optimizer, typically stored in its state_dict, are explicitly initialized before the tracing process begins. This is crucial for Cerebras WSC compatibility. The Cerebras framework automatically calls preinitialize() at the appropriate time, simplifying compliance. The function can also be called within the optimizer’s __init__ method for traditional GPU-based training, providing a flexible implementation.
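A rough sketch of what such an optimizer might look like, written here against the standard torch.optim.Optimizer base class with hypothetical names (the actual Cerebras base class and call sequence may differ):

```python
import torch
from torch.optim import Optimizer


class MomentumSGD(Optimizer):
    """Illustrative optimizer whose state is created up front in preinitialize()."""

    def __init__(self, params, lr=0.01, momentum=0.9):
        super().__init__(params, dict(lr=lr, momentum=momentum))
        # For eager GPU/CPU training, preinitialize() can simply be called here;
        # on the Cerebras WSC the framework calls it at the appropriate time instead.
        self.preinitialize()

    def preinitialize(self):
        # Explicitly create every stateful tensor before tracing begins,
        # rather than lazily creating it inside step().
        for group in self.param_groups:
            for p in group["params"]:
                self.state[p]["momentum"] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure=None):
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                buf = self.state[p]["momentum"]
                # In-place updates preserve the buffer's identity from step to step.
                buf.mul_(group["momentum"]).add_(p.grad)
                # Apply lr by multiplication rather than as alpha, in case lr is a
                # traced tensor (see the note on the injected lr tensor below).
                p.add_(buf * -group["lr"])
```

On a GPU this behaves like ordinary SGD with momentum; the important point for tracing is that every state tensor already exists before the first step() is traced.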
The Cerebras optimizer wrapper injects a traced tensor for param_group["lr"] in order to implement learning rate schedules dynamically, based on a traced global_step tensor. Passing these traced tensors to operations that expect scalar values can lead to unexpected behavior or errors due to implicit conversion attempts.
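A hedged sketch of how to accommodate this inside step() (p, update, and the group dictionary below are hypothetical stand-ins):

```python
import torch

p = torch.randn(10)
update = torch.randn(10)
group = {"lr": torch.tensor(0.01)}  # stand-in for the traced lr tensor

# Problematic: alpha must be a true Python scalar, not a traced tensor.
# p.add_(update, alpha=-group["lr"])

# Safe: apply the (possibly traced) learning rate as a tensor multiplication.
p.add_(update * -group["lr"])
```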
Keep the graph static by exploring alternative tensor-based implementations or by restructuring the code to avoid non-static graph constructs.
See cerebras.pytorch.optim.AdamBase for an example of this.
CSConfig configurations, such as num_csx and num_workers_per_csx, are global across all executions.
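As a loose sketch only (the import path and argument names below are assumptions based on the identifiers mentioned above, not confirmed by this document), such a configuration might be set up once for the whole run:

```python
from cerebras.pytorch.utils import CSConfig  # assumed import path

# These settings apply globally, to every execution in the run.
cs_config = CSConfig(
    num_csx=1,              # number of CS-X systems to use
    num_workers_per_csx=1,  # input workers per CS-X system
)
```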