Our typical workflow involves using a training script provided in the Cerebras Model Zoo. However, if that training loop is insufficient for your model's needs, you may write your own training loop using the Cerebras PyTorch API.

The following steps cover only the minimum code required to run a simple, small model on the Cerebras Wafer-Scale Cluster. To extend the script with features such as learning rate scheduling or gradient scaling, see the Further Reading section at the end of this page.

Prerequisites

You have installed the cerebras.pytorch package in your environment.

Validate the Package Installation

Run the following in a Python interpreter to validate that the cerebras.pytorch package is installed correctly:

import cerebras.pytorch as cstorch

From here on, we will use cstorch as the alias for cerebras.pytorch.
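If the import succeeds without raising an error, the package is installed. As an additional quick check, you can print the installed version (this assumes the package exposes a standard __version__ attribute):

import cerebras.pytorch as cstorch

print(cstorch.__version__)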

Configure the Wafer Scale Cluster

To configure the Cerebras Wafer-Scale cluster, construct a ClusterConfig object, then use that to construct a backend object:

cluster_config = cstorch.distributed.ClusterConfig(
    mgmt_address=mgmt_address,
    max_wgt_servers=1,
    max_act_per_csx=1,
    num_workers_per_csx=1,
)

backend = cstorch.backend(
    "CSX",
    cluster_config=cluster_config,
)

See the class documentation for ClusterConfig to view all configurable options.

Most options have reasonable defaults and do not need to be changed.
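For example, if the defaults are sufficient for your run, the configuration can be as simple as the sketch below. This assumes the management address may be omitted when it can be discovered from your environment:

backend = cstorch.backend(
    "CSX",
    # All ClusterConfig options left at their defaults
    cluster_config=cstorch.distributed.ClusterConfig(),
)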

Define Your Model

When using the Cerebras PyTorch API, define your model the same way you would in a vanilla PyTorch workflow:

import torch
import torch.nn.functional as F


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(784, 256)
        self.fc2 = torch.nn.Linear(256, 10)

    def forward(self, inputs):
        inputs = torch.flatten(inputs, 1)
        outputs = F.relu(self.fc1(inputs))
        return F.relu(self.fc2(outputs))


model = Model()

Weight initialization for large models can cause out-of-memory errors, and eagerly initializing extremely large models can also be very slow. See the page on Efficient weight initialization for how to work around this issue.
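As a preview of that page, a common pattern is to construct the model inside the backend's device context so that its parameters are initialized lazily instead of eagerly on the host. This is a minimal sketch, assuming the backend object exposes its device as a context manager as described on that page:

with backend.device:
    # Parameters created inside this scope can be initialized lazily,
    # avoiding materializing the full model eagerly in host memory.
    model = Model()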

Compile Your Model

Once the model has been instantiated, compile it by calling cerebras.pytorch.compile. For example:

compiled_model = cstorch.compile(model, backend)

The call to cstorch.compile does not actually compile the model. Similar to torch.compile, it only prepares the model for compilation. Compilation happens after the first iteration, once the input shapes are known.

Optimize Model Parameters

To optimize model parameters using the Cerebras Wafer-Scale Cluster, you must use a Cerebras-compliant optimizer. Exact drop-in replacements for all commonly used optimizers are available in cerebras.pytorch.optim. For example:

optimizer = cstorch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

If you are interested in writing your own Cerebras-compliant custom optimizer, see the page on Writing Custom Optimizers.

Dataloaders

To send data to the Wafer-Scale cluster, wrap your PyTorch dataloader with cerebras.pytorch.utils.data.DataLoader.

For example:

def get_torch_dataloader(batch_size, train):
    from torchvision import datasets, transforms

    train_dataset = datasets.MNIST(
        "./data",
        train=train,
        download=True,
        transform=transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
            ]
        ),
        target_transform=transforms.Lambda(
            lambda x: torch.as_tensor(x, dtype=torch.int32)
        ),
    )

    return torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True
    )


training_dataloader = cstorch.utils.data.DataLoader(
    get_torch_dataloader, batch_size=64, train=True
)

The Cerebras PyTorch dataloader takes a callable that constructs a PyTorch dataloader, rather than a dataloader instance. This approach ensures that each worker in the cluster can independently construct its own dataloader instance, so data loading is parallelized across workers.
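Because the argument is a callable, you can also bind arguments up front rather than passing them through the Cerebras dataloader. The following is a sketch, not part of the script above; eval_dataloader is a hypothetical name used only for illustration:

from functools import partial

# Each worker calls the resulting zero-argument callable to construct
# its own PyTorch dataloader for the evaluation split.
eval_dataloader = cstorch.utils.data.DataLoader(
    partial(get_torch_dataloader, batch_size=64, train=False)
)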

Define the Training Step

To execute a single training iteration on the Wafer-Scale Cluster, you first need to capture all operations intended to run on the cluster. Do this by defining a function that includes all actions for a single training iteration and decorating it with cerebras.pytorch.trace.

For example:

loss_fn = torch.nn.CrossEntropyLoss()

@cstorch.trace
def training_step(inputs, targets):
    outputs = compiled_model(inputs)
    loss = loss_fn(outputs, targets)

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    return loss

This function is traced and sent to the cluster for compilation and execution.

Define an Execution

To configure an execution run on the Cerebras Wafer-Scale Cluster, define an instance of cerebras.pytorch.utils.data.DataExecutor.

For example:

train_executor = cstorch.utils.data.DataExecutor(
    training_dataloader,
    num_steps=100,
    checkpoint_steps=50,
)

It takes the Cerebras PyTorch dataloader to be used during the run, the total number of steps to run for, and the interval at which checkpoints will be taken.
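For example, you could define a second executor for a short evaluation pass. This is a sketch only: eval_dataloader refers to the hypothetical evaluation dataloader sketched earlier, and it assumes checkpoint_steps may simply be omitted when no checkpoints are needed:

eval_executor = cstorch.utils.data.DataExecutor(
    eval_dataloader,   # hypothetical evaluation dataloader
    num_steps=10,      # run a short, fixed-length evaluation pass
)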

Train Your Model

Once the above is defined, you can iterate through the executor to train your model.

@cstorch.step_closure
def print_loss(mode, loss: torch.Tensor, step: int):
    print(f"{mode} Loss {step}: {loss.item()}")


@cstorch.checkpoint_closure
def save_checkpoint(step):
    cstorch.save(
        {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        },
        f"checkpoint_{step}.mdl",
    )

global_step = 0

for inputs, targets in train_executor:
    loss = training_step(inputs, targets)
    print_loss("Training", loss, global_step)
    global_step += 1
    save_checkpoint(global_step)

  • Notice how the loss is passed into a function decorated with step_closure. This is required to retrieve the loss value from the cluster before it can be used. See the page on step closures for more details.

  • Also, notice how checkpoints are saved inside a function decorated with checkpoint_closure. This is required to retrieve the model weights and optimizer state from the cluster before they can be saved. See the page on saving checkpoints for more details.
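To resume training from one of these checkpoints later, you can load the saved state back into the model and optimizer. The following is a minimal sketch; it assumes that cstorch.load reads the checkpoint format written by cstorch.save, analogous to torch.load:

state_dict = cstorch.load("checkpoint_50.mdl")

# Restore the model weights and optimizer state saved by save_checkpoint
model.load_state_dict(state_dict["model"])
optimizer.load_state_dict(state_dict["optimizer"])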

Putting It All Together

Combining all of the above steps yields a minimal training script that trains a simple fully connected model on the MNIST dataset:

# Import the Cerebras PyTorch module
import cerebras.pytorch as cstorch

# Define a model
import torch
import torch.nn.functional as F


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(784, 256)
        self.fc2 = torch.nn.Linear(256, 10)

    def forward(self, inputs):
        inputs = torch.flatten(inputs, 1)
        outputs = F.relu(self.fc1(inputs))
        return F.relu(self.fc2(outputs))

backend = cstorch.backend(
    "CSX",
    cluster_config=cstorch.distributed.ClusterConfig(
        mgmt_address=mgmt_address,
        max_wgt_servers=1,
        max_act_per_csx=1,
        num_workers_per_csx=1,
    ),
)

model = Model()

# Compile the model
compiled_model = cstorch.compile(model, backend=backend)

# Define an optimizer
optimizer = cstorch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)


# Define a data loader
def get_torch_dataloader(batch_size, train):
    from torchvision import datasets, transforms

    train_dataset = datasets.MNIST(
        "./data",
        train=train,
        download=True,
        transform=transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
            ]
        ),
        target_transform=transforms.Lambda(
            lambda x: torch.as_tensor(x, dtype=torch.int32)
        ),
    )

    return torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True
    )


training_dataloader = cstorch.utils.data.DataLoader(
    get_torch_dataloader, batch_size=64, train=True
)

# Define the training step
loss_fn = torch.nn.CrossEntropyLoss()

@cstorch.trace
def training_step(inputs, targets):
    outputs = compiled_model(inputs)
    loss = loss_fn(outputs, targets)

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    return loss


@cstorch.step_closure
def print_loss(loss: torch.Tensor, step: int):
    print(f"Train Loss {step}: {loss.item()}")


@cstorch.checkpoint_closure
def save_checkpoint(step):
    cstorch.save(
        {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        },
        f"checkpoint_{step}.mdl",
    )

global_step = 0

train_executor = cstorch.utils.data.DataExecutor(
    training_dataloader,
    num_steps=100,
    checkpoint_steps=50,
)
model.train()
for inputs, targets in train_executor:
    loss = training_step(inputs, targets)
    print_loss(loss, global_step)
    global_step += 1
    save_checkpoint(global_step)

Further Reading