Our typical workflow uses a training script provided in the Cerebras Model Zoo. However, if that training loop does not meet your model's needs, you can write your own training loop using the Cerebras PyTorch API.

The following steps show how to write a custom training loop for a simple, fully connected model trained on the MNIST dataset.

Note: these steps cover only the minimum code required to run a small, simple model on the Cerebras Wafer-Scale Cluster. To extend the script with features such as learning rate scheduling or gradient scaling, continue to the Further Reading section.

Prerequisites

You have installed the cerebras.pytorch package in your environment.

Validate the Package Installation

To check whether the cerebras.pytorch package is installed correctly, run the following import in a Python interpreter. It should complete without raising an ImportError:

import cerebras.pytorch as cstorch

From here on, we use cstorch as the alias for cerebras.pytorch.
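If you would rather check for the package without letting an ImportError propagate, the standard library can look the module up first. The sketch below is one way to do that. The helper name is_installed is made up for illustration and is not part of cerebras.pytorch.

```python
import importlib.util


def is_installed(module_name: str) -> bool:
    """Return True if `module_name` can be located by the import system."""
    try:
        # For dotted names, find_spec imports the parent package first.
        return importlib.util.find_spec(module_name) is not None
    except ImportError:
        # Raised when a parent package (e.g. "cerebras") is itself missing.
        return False


print("cerebras.pytorch installed:", is_installed("cerebras.pytorch"))
```

A False result here means the import above would fail, so the package should be installed before continuing.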

Define Your Model

When using the Cerebras PyTorch API, you can define your model in the same way you would in a vanilla PyTorch workflow:

import torch
import torch.nn.functional as F


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(784, 256)
        self.fc2 = torch.nn.Linear(256, 10)

    def forward(self, inputs):
        inputs = torch.flatten(inputs, 1)
        outputs = F.relu(self.fc1(inputs))
        return F.relu(self.fc2(outputs))


model = Model()

Weight initialization for large models can cause out-of-memory errors, and eagerly initializing extremely large models can also be very slow. See the page on Efficient weight initialization to learn how to work around these issues.
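For a sense of scale, the parameter count of the model defined above can be worked out by hand. The short sketch below does the arithmetic, relying on the fact that torch.nn.Linear includes a bias term by default. The helper linear_param_count is a hypothetical name used only for this illustration.

```python
def linear_param_count(in_features: int, out_features: int) -> int:
    """Weights plus biases of one fully connected layer."""
    return in_features * out_features + out_features


fc1_params = linear_param_count(784, 256)  # 200,960
fc2_params = linear_param_count(256, 10)   # 2,570
total_params = fc1_params + fc2_params     # 203,530

# At 4 bytes per float32 parameter this is well under 1 MB.
total_bytes = total_params * 4             # 814,120
print(total_params, total_bytes)
```

At roughly 0.2 million parameters, this model is far too small for eager initialization to be a concern. The memory and speed issues described above apply to models that are many orders of magnitude larger.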

Compile Your Model

Once the model has been instantiated, compile it by calling cerebras.pytorch.compile, e.g.

compiled_model = cstorch.compile(model, backend="CSX")

You must pass in the backend you wish to compile the model with. You can simply pass in the type of backend if you wish to use all default arguments, or you can instantiate a backend object using cerebras.pytorch.backend to customize it, e.g.

backend = cstorch.backend("CSX", artifact_dir="./artifact_dir")
compiled_model = cstorch.compile(model, backend)

The call to cstorch.compile doesn't actually compile the model. Similar to torch.compile, it only prepares the model for compilation. Compilation happens on the first iteration, once the input shapes are known.
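The idea of deferring work until the first input arrives can be illustrated with a toy wrapper in plain Python. This is only a conceptual analogy and says nothing about how cstorch.compile is implemented. The class LazyCompiled and its attribute compiled_for are invented for this sketch.

```python
class LazyCompiled:
    """Toy stand-in that defers an expensive 'compile' until the first call."""

    def __init__(self, fn):
        self.fn = fn
        self.compiled_for = None  # input shape recorded on the first call

    def __call__(self, batch):
        shape = (len(batch), len(batch[0]))
        if self.compiled_for is None:
            # First iteration: the shape is now known, so 'compile' here.
            self.compiled_for = shape
        return self.fn(batch)


row_sums = LazyCompiled(lambda batch: [sum(row) for row in batch])
assert row_sums.compiled_for is None       # nothing happens at wrap time

result = row_sums([[1, 2, 3], [4, 5, 6]])  # first call records the shape
print(result, row_sums.compiled_for)
```

Wrapping the function is cheap. The shape-dependent work only happens when the first batch is seen, which mirrors the behavior described above.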

Optimize Model Parameters

To optimize model parameters using the Cerebras Wafer-Scale cluster, you must use a Cerebras-compliant optimizer. There are exact drop-in replacements for all commonly used optimizers available in cerebras.pytorch.optim, e.g.

optimizer = cstorch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

If you are interested in writing your own custom Cerebras-compliant optimizer, see the page on Writing custom optimizers.

DataLoaders

To send data to the Wafer-Scale cluster, you must wrap your PyTorch dataloader with cerebras.pytorch.utils.data.DataLoader, e.g.

def get_torch_dataloader(batch_size, train):
    from torchvision import datasets, transforms

    train_dataset = datasets.MNIST(
        "./data",
        train=train,
        download=True,
        transform=transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
            ]
        ),
        target_transform=transforms.Lambda(
            lambda x: torch.as_tensor(x, dtype=torch.int32)
        ),
    )

    return torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True
    )


training_dataloader = cstorch.utils.data.DataLoader(
    get_torch_dataloader, batch_size=64, train=True
)

The Cerebras PyTorch dataloader takes a callable that returns a PyTorch dataloader, along with the arguments to pass to that callable. It must be done this way so that each worker can create its own PyTorch dataloader instance, which maximizes distributed parallelism.
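The reason for passing a callable rather than a ready-made dataloader can be seen in a small plain-Python sketch. Here each simulated worker calls the factory itself and gets an independent iterator. The names get_toy_dataloader and spawn_worker_loaders are hypothetical, and the sketch does not model how the real workers shard or stream data.

```python
def get_toy_dataloader(batch_size, num_samples=8):
    """Hypothetical factory: builds a fresh iterator of batches on each call."""
    samples = list(range(num_samples))
    batches = [
        samples[i:i + batch_size] for i in range(0, num_samples, batch_size)
    ]
    return iter(batches)


def spawn_worker_loaders(factory, num_workers, **kwargs):
    """Each simulated worker calls the factory to build its own loader."""
    return [factory(**kwargs) for _ in range(num_workers)]


loaders = spawn_worker_loaders(get_toy_dataloader, num_workers=2, batch_size=4)

# Advancing one worker's loader does not affect the other's position.
first_from_worker0 = next(loaders[0])
second_from_worker0 = next(loaders[0])
first_from_worker1 = next(loaders[1])
print(first_from_worker0, second_from_worker0, first_from_worker1)
```

Had a single iterator object been shared instead, the workers would compete for the same position in the data. Passing the factory and its keyword arguments lets every worker construct its own instance.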

Define the Training Step

To run a single training iteration on the Cerebras Wafer-Scale cluster, we must first capture everything that is intended to run on the cluster. To do this, define a function that contains everything that happens in a single training iteration, and decorate it with cerebras.pytorch.trace.

For example:

loss_fn = torch.nn.CrossEntropyLoss()

@cstorch.trace
def training_step(inputs, targets):
    outputs = compiled_model(inputs)
    loss = loss_fn(outputs, targets)

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    return loss

This function gets traced and sent to the cluster for compilation and execution.

Define an Execution

To program an execution run on the Cerebras Wafer-Scale cluster, you must define an instance of the cerebras.pytorch.utils.data.DataExecutor, e.g.

train_executor = cstorch.utils.data.DataExecutor(
    training_dataloader,
    num_steps=100,
    checkpoint_steps=50,
)

It takes the Cerebras PyTorch dataloader to use during the run, the total number of steps to run for, and the interval (in steps) at which checkpoints are taken.
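To make the interval concrete, the sketch below lists the steps at which checkpoints would fall for the values used above. It assumes that steps are counted from 1 and that a checkpoint is taken at every multiple of checkpoint_steps. That is an assumption for illustration, so consult the DataExecutor documentation for the exact behavior. The function checkpoint_schedule is a hypothetical helper.

```python
def checkpoint_schedule(num_steps: int, checkpoint_steps: int) -> list:
    """Steps (1-based) that are multiples of the checkpoint interval.

    Assumption: a checkpoint falls on every multiple of `checkpoint_steps`
    up to and including `num_steps`.
    """
    if checkpoint_steps <= 0:
        return []
    return list(range(checkpoint_steps, num_steps + 1, checkpoint_steps))


schedule = checkpoint_schedule(num_steps=100, checkpoint_steps=50)
print(schedule)
```

Under that assumption, a run of 100 steps with an interval of 50 yields two checkpoints, at steps 50 and 100.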

Configuring the Cerebras Wafer Scale Cluster

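To configure the Cerebras Wafer-Scale cluster, construct a CSConfig object. In the snippet below, mgmt_address is a placeholder variable that you must set to your cluster's management address: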

cs_config = cstorch.utils.CSConfig(
    mgmt_address=mgmt_address,
    transfer_processes=1,
    max_wgt_servers=1,
    max_act_per_csx=1,
    num_workers_per_csx=1,
)

This object can then be passed to the cerebras.pytorch.utils.data.DataExecutor:

train_executor = cstorch.utils.data.DataExecutor(
    training_dataloader,
    num_steps=100,
    checkpoint_steps=50,
    cs_config=cs_config,
)

See the class documentation for CSConfig for the full list of configurable options.

Most options have reasonable defaults and do not need to be changed.
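The snippets above use a mgmt_address variable without defining it. One possible way to supply it, sketched below, is to read it from an environment variable so that the address is not hard-coded in the script. The variable name CS_MGMT_ADDRESS, the helper resolve_mgmt_address, and the example address are all hypothetical and not part of the Cerebras API.

```python
import os


def resolve_mgmt_address(env=os.environ, var_name="CS_MGMT_ADDRESS"):
    """Read the management address from a mapping (hypothetical variable name)."""
    address = env.get(var_name, "").strip()
    if not address:
        raise RuntimeError(
            f"Set {var_name} to your cluster's management address."
        )
    return address


# Example with an explicit mapping instead of the real environment.
mgmt_address = resolve_mgmt_address({"CS_MGMT_ADDRESS": "mgmt.example:9000"})
print(mgmt_address)
```

Failing early with a clear message is usually preferable to passing an empty address into the cluster configuration.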

Train Your Model

Once the above is defined, you can iterate through the executor to train your model.

@cstorch.step_closure
def print_loss(mode, loss: torch.Tensor, step: int):
    print(f"{mode} Loss {step}: {loss.item()}")


@cstorch.checkpoint_closure
def save_checkpoint(step):
    cstorch.save(
        {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        },
        f"checkpoint_{step}.mdl",
    )

global_step = 0

for inputs, targets in train_executor:
    loss = training_step(inputs, targets)
    print_loss("Training", loss, global_step)
    global_step += 1
    save_checkpoint(global_step)

  • Notice how the loss is passed into a function decorated with step_closure. This is required to retrieve the loss value from the Cerebras Wafer-Scale Cluster before it can be used. Please see the page on step closures for more details.

  • Also notice how checkpoints are saved inside a function decorated with checkpoint_closure. This is required to retrieve the model weights and optimizer state from the Cerebras Wafer-Scale Cluster before they can be saved. Please see the page on saving checkpoints for more details.

Putting it All Together

Combining all of the above steps, we can create a minimal training script for a simple, fully connected model trained on the MNIST dataset:

# Import the Cerebras PyTorch module
import cerebras.pytorch as cstorch

# Define a model
import torch
import torch.nn.functional as F


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(784, 256)
        self.fc2 = torch.nn.Linear(256, 10)

    def forward(self, inputs):
        inputs = torch.flatten(inputs, 1)
        outputs = F.relu(self.fc1(inputs))
        return F.relu(self.fc2(outputs))


model = Model()

# Compile the model
compiled_model = cstorch.compile(model, backend="CSX")

# Define an optimizer
optimizer = cstorch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)


# Define a data loader
def get_torch_dataloader(batch_size, train):
    from torchvision import datasets, transforms

    train_dataset = datasets.MNIST(
        "./data",
        train=train,
        download=True,
        transform=transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
            ]
        ),
        target_transform=transforms.Lambda(
            lambda x: torch.as_tensor(x, dtype=torch.int32)
        ),
    )

    return torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True
    )


training_dataloader = cstorch.utils.data.DataLoader(
    get_torch_dataloader, batch_size=64, train=True
)

# Define the training step
loss_fn = torch.nn.CrossEntropyLoss()

@cstorch.trace
def training_step(inputs, targets):
    outputs = compiled_model(inputs)
    loss = loss_fn(outputs, targets)

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    return loss


@cstorch.step_closure
def print_loss(loss: torch.Tensor, step: int):
    print(f"Train Loss {step}: {loss.item()}")


@cstorch.checkpoint_closure
def save_checkpoint(step):
    cstorch.save(
        {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        },
        f"checkpoint_{step}.mdl",
    )


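# NOTE: mgmt_address is not defined in this script. Set it to your
# cluster's management address before running.
cs_config = cstorch.utils.CSConfig(
    mgmt_address=mgmt_address,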
    transfer_processes=1,
    max_wgt_servers=1,
    max_act_per_csx=1,
    num_workers_per_csx=1,
)

global_step = 0

train_executor = cstorch.utils.data.DataExecutor(
    training_dataloader,
    num_steps=100,
    checkpoint_steps=50,
    cs_config=cs_config,
)
model.train()
for inputs, targets in train_executor:
    loss = training_step(inputs, targets)
    print_loss(loss, global_step)
    global_step += 1
    save_checkpoint(global_step)

Further Reading