Sparse models require fewer FLOPs per step and less memory to store, making them compelling architectures for deep learning research and application. Sparse training algorithms set a portion of the model’s weights to zero and often adjust that proportion throughout training. Beyond reducing compute, sparsity can act as a regularizer and may improve generalization to new data. Several strategies exist for choosing which weights to zero, each suited to different models and tasks.

The Cerebras PyTorch API enables practitioners to easily implement and extend sparsity algorithms and schedules, and to train directly with sparse weights on the Cerebras system, reducing the FLOPs required per step.

How to Sparsify Your Model

Sparsifying a model is straightforward with the cstorch API. Here’s an example where 30% of the values in every parameter (such as weights, biases, embeddings, among others) are set to zero prior to training:

import torch
import cerebras.pytorch as cstorch

# Initialize your backend
backend = cstorch.backend(...)

# Define your model within the backend's device context
with backend.device:
    model: torch.nn.Module = ...

# Compile the model with the backend
compiled_model = cstorch.compile(model, backend)

# Initialize a sparsity algorithm with 30% sparsity
sparsity = cstorch.sparse.Static(sparsity=0.3)
# Apply the sparsity algorithm to the model
model.apply(sparsity)

After the call to model.apply(sparsity), your model parameters are sparsified, enhancing training efficiency.
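To make the 30% figure concrete, the following framework-free sketch builds a random 0/1 mask over a flat list of values and zeroes the pruned entries. It does not call cstorch: the helper `static_mask`, the seed, and the sample values are hypothetical, and how `cstorch.sparse.Static` actually chooses which entries to prune is not described on this page.

```python
import random


def static_mask(num_values: int, sparsity: float, seed: int = 0) -> list:
    """Return a 0/1 mask in which a `sparsity` fraction of entries is 0."""
    num_pruned = round(num_values * sparsity)
    rng = random.Random(seed)
    pruned = set(rng.sample(range(num_values), num_pruned))
    return [0 if i in pruned else 1 for i in range(num_values)]


# Ten illustrative weight values standing in for one parameter tensor
weights = [0.5, -1.2, 0.3, 2.0, -0.7, 0.9, -0.1, 1.1, 0.4, -0.6]

mask = static_mask(len(weights), sparsity=0.3)

# Entries with mask 0 are pruned; the rest keep their original value
sparse_weights = [w if m else 0.0 for w, m in zip(weights, mask)]
```

With ten values and a sparsity of 0.3, three entries are pruned and seven keep their original values.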

Important considerations

1. Apply sparsity only after the model has been compiled with cstorch.compile, so that all parameters are on the Cerebras device.

2. To exclude certain parameters from sparsity, set param.requires_dense = True. If a parameter does not have this attribute, the algorithm assumes that it is False.
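The default described in the second consideration (a missing attribute counts as False) can be pictured with a small stand-alone sketch. The objects below are plain `SimpleNamespace` stand-ins rather than real parameters, and `is_sparsifiable` is a hypothetical helper, not part of cstorch.

```python
from types import SimpleNamespace


def is_sparsifiable(param) -> bool:
    # A parameter without the attribute is treated as requires_dense = False
    return not getattr(param, "requires_dense", False)


# Stand-ins for named parameters; only the bias is marked as dense
params = {
    "fc.weight": SimpleNamespace(),
    "fc.bias": SimpleNamespace(requires_dense=True),
}

to_sparsify = [name for name, p in params.items() if is_sparsifiable(p)]
```

In this sketch only `fc.weight` remains a candidate for sparsification.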

Sparsifying Optimizers

For training, simply sparsifying the model’s parameters is insufficient; the optimizer must be sparsified as well. To extend sparsity to your optimizer:

optimizer.apply(sparsity)

When you sparsify an optimizer, you not only adjust the optimizer’s state but also install hooks that maintain and update the sparsity pattern during training, specifically whenever optimizer.step() is executed.

Executing optimizer.apply(sparsity) transforms your optimizer into a sparse optimizer.
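A framework-free sketch of why the optimizer needs this treatment: a dense update step writes to pruned positions, both in the weights and in any per-parameter state such as a momentum buffer, so the pattern has to be re-imposed after the step. The functions and values below are hypothetical and only mimic the idea of a post-step hook; they are not how cstorch implements it.

```python
def sgd_momentum_step(weights, grads, momentum_buf, lr=0.1, beta=0.9):
    """One dense SGD-with-momentum step on plain Python lists (in place)."""
    for i, g in enumerate(grads):
        momentum_buf[i] = beta * momentum_buf[i] + g
        weights[i] -= lr * momentum_buf[i]


def apply_mask(values, mask):
    """Zero every entry whose mask value is 0 (in place)."""
    for i, m in enumerate(mask):
        if m == 0:
            values[i] = 0.0


mask = [1, 0, 1, 0]
weights = [0.5, 0.0, -0.3, 0.0]
momentum_buf = [0.0, 0.0, 0.0, 0.0]
grads = [0.1, 0.2, -0.1, 0.4]  # dense gradients also touch pruned positions

sgd_momentum_step(weights, grads, momentum_buf)
weights_before_hook = list(weights)  # pruned entries are now non-zero

# Stand-in for a post-step hook: re-impose the pattern on the parameter
# and on the state tensor that shares its shape
apply_mask(weights, mask)
apply_mask(momentum_buf, mask)
```

Masking the momentum buffer as well as the weights matters here: otherwise the stale momentum at a pruned position would keep pushing that weight away from zero on later steps.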

Important considerations

1. The sparsity algorithm targets every optimizer state tensor associated with a parameter, provided the state tensor matches the parameter’s shape. To exempt a state tensor from sparsification, mark it as dense:
optimizer.state[p][name].requires_dense = True

If a state tensor does not have this attribute, the algorithm assumes that it is False.

2. Sparsity algorithms typically include a hook that updates the sparsity pattern after each optimizer.step() call. This automatic update feature can be deactivated if necessary:
sparsity.autoupdate = False

Sparsity Algorithms

The cstorch API offers several out-of-the-box sparsity algorithms, including Static, GMP, SET, and RigL.

Composing Sparsity Algorithms

You can apply distinct sparsity strategies to different parameter groups within your model. For instance, one group of weights might be statically sparsified at 30% (that is, 30% of its values are set to zero), while another group undergoes dynamic sparsity adjustments using the SET algorithm. This is achieved by composing different sparsity algorithms with the Group class.

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(25, 20)
        self.fc2 = torch.nn.Linear(20, 10)
    ...

sparsity = cstorch.sparse.Group({
    "fc1.*": cstorch.sparse.Static(sparsity=0.3),
    "fc2.*": cstorch.sparse.SET(
        sparsity=0.9, update={"freq": 1000}, drop_fraction=0.3
    )
})

model.apply(sparsity)

This grouped sparsity algorithm applies static sparsity to all model parameters corresponding to the fc1.* glob pattern, while employing the SET sparsity algorithm for parameters that match the fc2.* glob pattern.
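To see which parameter names such patterns would select, the sketch below uses Python's standard `fnmatch` module. This assumes the Group patterns behave like shell-style globs as implemented by `fnmatch`, which this page does not confirm; the parameter names are the ones a model with `fc1` and `fc2` linear layers would typically expose, and the string labels are placeholders for the algorithm objects.

```python
from fnmatch import fnmatch

# Parameter names as a model with fc1 and fc2 Linear layers would report them
param_names = ["fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias"]

# Glob pattern -> placeholder label for the algorithm handling the match
groups = {"fc1.*": "Static", "fc2.*": "SET"}

assignment = {
    name: label
    for name in param_names
    for pattern, label in groups.items()
    if fnmatch(name, pattern)
}
```

Under this assumption a pattern such as `fc1.*` also matches `fc1.bias`, so a narrower pattern like `fc1.weight` (or `requires_dense`) would be needed to leave biases dense.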

Writing Custom Sparsity Algorithms

All sparsity algorithms must inherit from the SparsityAlgorithm base class.

class CustomSparsity(cstorch.sparse.SparsityAlgorithm):

    def __init__(self, ..., **kwargs):
        ...
        super().__init__(**kwargs)

    ...

    def update(self):
        ...
        for sparse_param in self.sparse_params.values():
            p = sparse_param.data
            mask = sparse_param.mask
            ...
            sparse_param.mask = new_mask
        ...

The only abstract method that must be overridden is update, which updates the sparsity patterns of all sparse parameters.

For algorithms that dynamically change the sparsity pattern, there is a convenient DynamicSparsityAlgorithm class that you can inherit from that takes care of many of the implementation details required to facilitate dynamic sparsity.

class CustomDynamicSparsity(cstorch.sparse.DynamicSparsityAlgorithm):

    def __init__(self, ..., **kwargs):
        ...
        super().__init__(**kwargs)

    ...

    def update_mask(self, p, mask, sparsity) -> torch.Tensor:
        ...

DynamicSparsityAlgorithm already implements update, but it exposes a new abstract method, update_mask, that must be overridden instead. update_mask takes in the existing sparsity pattern in the form of a mask tensor and must return the new sparsity pattern, also in the form of a mask tensor.

See GMP, SET, and RigL for examples of how to implement update_mask.
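As a rough illustration of the mask-in, mask-out contract, here is a simplified SET-style update written on plain Python lists: it drops a fraction of the smallest-magnitude active weights and regrows the same number of connections at random inactive positions. This is a sketch of the general idea only, not the cstorch implementation; the function name, its `drop_fraction` and `seed` arguments, and the list-based representation are all assumptions made for the example.

```python
import random


def set_style_update_mask(p, mask, drop_fraction, seed=0):
    """Return a new 0/1 mask with the same number of active entries.

    Drops the `drop_fraction` smallest-magnitude active weights, then
    regrows the same number at randomly chosen inactive positions.
    """
    active = [i for i, m in enumerate(mask) if m == 1]
    inactive = [i for i, m in enumerate(mask) if m == 0]
    num_swap = min(int(len(active) * drop_fraction), len(inactive))

    # Drop: active positions with the smallest |weight|
    dropped = sorted(active, key=lambda i: abs(p[i]))[:num_swap]
    # Regrow: random positions that were inactive before this update
    regrown = random.Random(seed).sample(inactive, num_swap)

    new_mask = list(mask)
    for i in dropped:
        new_mask[i] = 0
    for i in regrown:
        new_mask[i] = 1
    return new_mask


p = [0.9, 0.0, -0.05, 0.0, 0.4, 0.0, -0.7, 0.0]
mask = [1, 0, 1, 0, 1, 0, 1, 0]
new_mask = set_style_update_mask(p, mask, drop_fraction=0.3)
```

Because every dropped connection is replaced by a regrown one, the overall sparsity level stays constant while the pattern changes.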

In addition, many building blocks are provided that can be used directly, inherited from, or composed to help build new DynamicSparsityAlgorithm subclasses. See Customizing Sparsity & Reference for more details.

Once you’ve written your custom sparsity algorithm, you can use it directly as long as it is available in the global scope. You can also select it through a call to configure by setting the algorithm to the name of your custom sparsity algorithm class. By extension, you can use it in ModelZoo the same way, by setting the algorithm to that class name in your params YAML file (see sparsity_via_yaml for more details).

Implementation Notes

The Cerebras Wafer-Scale Cluster natively implements sparse computations in the Compressed Sparse Row (CSR) format. For user convenience, sparse models are represented as a combination of dense tensors and masks at the PyTorch level, with the compiler seamlessly converting between these representations.
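For readers unfamiliar with CSR, the sketch below converts a small dense matrix plus mask into the three standard CSR arrays (values, column indices, row pointers). It illustrates the format only; the function name is hypothetical, and the compiler's actual conversion is internal and not described on this page.

```python
def dense_mask_to_csr(dense, mask):
    """Convert a 2-D dense matrix plus a 0/1 mask into CSR arrays."""
    values, col_indices, row_ptr = [], [], [0]
    for row, mask_row in zip(dense, mask):
        for col, (v, m) in enumerate(zip(row, mask_row)):
            if m:  # keep only positions the mask marks as active
                values.append(v)
                col_indices.append(col)
        # row_ptr[r + 1] marks where row r ends in `values`
        row_ptr.append(len(values))
    return values, col_indices, row_ptr


dense = [[1.0, 2.0, 3.0],
         [4.0, 5.0, 6.0]]
mask = [[1, 0, 1],
        [0, 0, 1]]

values, col_indices, row_ptr = dense_mask_to_csr(dense, mask)
```

Only the three active entries are stored, which is where the memory saving of a sparse representation comes from.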

While PyTorch provides tools for representing sparse tensors and utilities for pruning networks, these features do not fully align with the needs of the Cerebras Wafer Scale Engine (WSE). Sparse tensors in PyTorch require specialized kernels and may not be entirely compatible with existing models and utilities. Notably, a torch.nn.Parameter cannot directly hold a sparse tensor without specific adjustments. The torch.nn.utils.prune utilities are convenient, but the asynchronous and precompiled nature of computation on the WSE requires a custom solution.

Similar to how torch.nn.utils.prune handles its mask tensors, when the sparsity algorithm is applied to the model, every sparsified parameter has a mask tensor registered as a stateful buffer next to it in the module that owns the parameter.

For example, take the following simple model:

model = torch.nn.Linear(2, 2, bias=False)

Initially, the model’s state dictionary appears as follows:

{
    "weight": torch.Tensor(...)
}

After applying sparsity, the state dictionary is augmented to include mask tensors, illustrating the model’s transition to a sparsified state:

{
    "weight": torch.Tensor(...),
    "weight_mask": torch.Tensor(...)
}

Here, the weight and weight_mask tensors collectively represent the sparsified weight, showing how sparsity is represented within the model’s architecture.
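The pairing can be sketched without any framework: each `<name>_mask` entry is matched with its `<name>` entry, and the effective sparse weight is the dense value wherever the mask is 1 and zero elsewhere. Nested lists stand in for tensors and the values are made up for illustration.

```python
# Nested lists stand in for the tensors in the state dictionary
state_dict = {
    "weight": [[0.8, -0.3], [0.5, 1.2]],
    "weight_mask": [[1, 0], [0, 1]],
}

effective = {}
for key, mask in state_dict.items():
    if not key.endswith("_mask"):
        continue
    name = key[: -len("_mask")]  # "weight_mask" -> "weight"
    effective[name] = [
        [w if m else 0.0 for w, m in zip(w_row, m_row)]
        for w_row, m_row in zip(state_dict[name], mask)
    ]
```

The dense `weight` entry keeps its stored values; the mask alone determines which of them take part in the sparse computation.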