

On the Cerebras Wafer-Scale Cluster, the client (your Python script) and the server (the cluster hardware) run asynchronously — the server does not wait for the client between steps. This prevents client-side bottlenecks like disk I/O or networking from slowing down training. However, this means that when your client code requests a tensor value (such as the loss), the server may not have computed it yet. For example, compilation happens during the first iteration of the training loop, so no tensors are available until compilation finishes and execution begins. To handle this, we introduce the concept of a step closure via the step_closure decorator, e.g.
@cstorch.step_closure
def closure(loss):
    print(f"Loss: {loss}")
Tensors passed into a step closure are fetched from the server and their values are materialized before the closure body runs. If a tensor is not yet available, the call waits until the server catches up to the current step and the value can be fetched, and only then invokes the closure.
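To make the deferred-execution idea concrete, here is a toy, self-contained sketch of the mechanism — not the Cerebras implementation or API. All names here (`pending`, `server_catch_up`, the string handles) are illustrative: closures are queued when called, and only run once the "server" supplies the computed values.

```python
from collections import deque

pending = deque()   # closures queued until the server "catches up"
server_values = {}  # name -> value, filled in asynchronously by the server

def step_closure(fn):
    """Toy stand-in: defer fn until its arguments have been computed."""
    def wrapper(*names):
        pending.append((fn, names))  # record the call; do not run yet
    return wrapper

def server_catch_up(computed):
    """Simulate the server finishing a step: values arrive, closures run."""
    server_values.update(computed)
    while pending:
        fn, names = pending.popleft()
        fn(*(server_values[n] for n in names))

@step_closure
def print_loss(loss):
    print(f"Loss: {loss}")

print_loss("loss")               # nothing printed yet: value not available
server_catch_up({"loss": 0.25})  # prints: Loss: 0.25
```

The real decorator materializes tensors rather than looking up named values, but the control flow is the same: calling the decorated function schedules work instead of executing it immediately.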

Example Usage

You can call step closures either inside the traced step function or outside of it:
@cstorch.step_closure
def check_nan(loss):
    assert not torch.isnan(loss).any()

@cstorch.step_closure
def print_loss(loss):
    print(f"Loss: {loss}")

@cstorch.trace
def training_step(inputs, targets):
    outputs = compiled_model(inputs)
    loss = loss_fn(outputs, targets)

    check_nan(loss)

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    return loss

for inputs, targets in executor:
    loss = training_step(inputs, targets)
    print_loss(loss)
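The blocking behavior — the closure waits until the server catches up — can be sketched with a background thread standing in for the cluster. This is a hypothetical harness for intuition only, not the cstorch implementation; `results` and `server` are invented names.

```python
import queue
import threading
import time

results = queue.Queue()  # channel from the "server" thread to the client

def server():
    """Pretend compilation/execution takes a while, then emit the loss."""
    time.sleep(0.1)
    results.put(0.25)

def step_closure(fn):
    def wrapper():
        value = results.get()  # blocks until the server catches up
        fn(value)
    return wrapper

@step_closure
def print_loss(loss):
    print(f"Loss: {loss}")

threading.Thread(target=server).start()
print_loss()  # waits ~0.1 s for the value, then prints: Loss: 0.25
```

The key point the toy preserves: the client does not fail when it asks for a value the server has not produced yet — it simply blocks until that step's results exist.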