Sparsifying Models
Learn strategies for integrating sparsity into Cerebras models to optimize performance and computational efficiency across neural network architectures.
Sparse models require fewer FLOPs per step and less memory to store, which makes them compelling for both deep learning research and applications. Sparse training algorithms set a portion of the model weights to zero and often adjust that proportion over the course of training. Beyond efficiency, sparsity can act as a regularizer that concentrates capacity on the most important weights, potentially improving generalization to new data. Many sparsity strategies exist, each suited to different models and tasks, which makes sparsity a versatile tool for optimizing machine learning models.
The Cerebras PyTorch API enables practitioners to easily implement and extend sparsity algorithms and schedules, and to train directly with sparse weights on the Cerebras system, significantly reducing the FLOPs per step.
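As a rough illustration of where the FLOPs savings come from, the idealized cost of a linear layer's forward matrix multiply scales with the number of nonzero weights. The sketch below is an idealized count under stated assumptions: 2 FLOPs (one multiply, one add) per surviving weight per token, with biases, activations, and any overhead of the sparse format ignored. Actual speedups depend on the hardware and kernels.

```python
def linear_forward_flops(in_features, out_features, tokens, sparsity=0.0):
    """Idealized FLOPs for the matrix multiply of one linear layer's forward pass.

    Counts 2 FLOPs (multiply + add) per nonzero weight per token; biases,
    activations, and sparse-format overhead are ignored.
    """
    nonzero_weights = round((1.0 - sparsity) * in_features * out_features)
    return 2 * nonzero_weights * tokens


dense_flops = linear_forward_flops(4096, 4096, tokens=2048)
sparse_flops = linear_forward_flops(4096, 4096, tokens=2048, sparsity=0.3)
# At 30% sparsity, roughly 70% of the dense FLOPs remain.
print(f"{sparse_flops / dense_flops:.2f}")  # 0.70
```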
How to Sparsify Your Model
Sparsifying a model is straightforward with the cstorch API. Here’s an example where 30% of the values in every parameter (weights, biases, embeddings, and so on) are set to zero before training:
After the call to model.apply(sparsity), your model parameters are sparsified.
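Conceptually, applying 30% static sparsity to a parameter means attaching a 0/1 mask in which 30% of the entries are zero and multiplying it into the weights. The plain-Python sketch below illustrates only that idea. The helper names static_mask and apply_mask are hypothetical, and drawing the zeroed positions uniformly at random is an assumption, not necessarily how cstorch initializes its masks.

```python
import random


def static_mask(values, sparsity, seed=0):
    """Return a 0/1 mask that zeroes round(sparsity * len(values)) entries.

    Hypothetical helper: the zeroed positions are drawn uniformly at random,
    which is an assumption made for this sketch.
    """
    n_zero = round(sparsity * len(values))
    rng = random.Random(seed)
    zero_idx = set(rng.sample(range(len(values)), n_zero))
    return [0 if i in zero_idx else 1 for i in range(len(values))]


def apply_mask(values, mask):
    """Elementwise product: masked-out entries become zero."""
    return [v * m for v, m in zip(values, mask)]


weights = [0.5, -1.2, 0.3, 0.8, -0.1, 2.0, -0.7, 0.05, 1.1, -0.4]
mask = static_mask(weights, sparsity=0.3)
sparse_weights = apply_mask(weights, mask)
print(mask.count(0))  # 3
```

The mask is kept separate from the weights, so the sparsity pattern can later be inspected or changed without losing the underlying values.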
Important considerations
1. Apply sparsity only after the model has been compiled with cstorch.compile, so that all parameters are already on the Cerebras device.
2. To exclude certain parameters from sparsity, set param.requires_dense = True. If a parameter does not have this attribute, the algorithm assumes that it is False.
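The opt-out rule in the second consideration can be sketched in plain Python. FakeParam is a hypothetical stand-in for a real parameter; with an actual torch parameter you would set the requires_dense attribute the same way, as described above. The getattr default reproduces the documented behavior of treating a missing attribute as False.

```python
from dataclasses import dataclass


@dataclass
class FakeParam:
    """Hypothetical stand-in for a model parameter (not a torch.nn.Parameter)."""
    name: str
    values: list


def sparsifiable(params):
    """Names of parameters a sparsity algorithm would target.

    Mirrors the documented rule: a missing requires_dense attribute is treated as False.
    """
    return [p.name for p in params if not getattr(p, "requires_dense", False)]


weight = FakeParam("fc1.weight", [0.1, -0.2, 0.3])
bias = FakeParam("fc1.bias", [0.0])
bias.requires_dense = True  # opt this parameter out of sparsity
print(sparsifiable([weight, bias]))  # ['fc1.weight']
```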
Sparsifying Optimizers
For training, simply sparsifying the model’s parameters is insufficient; the optimizer must be sparsified as well. To extend sparsity to your optimizer:
Sparsifying an optimizer does more than adjust the optimizer’s state: it also installs hooks that keep the sparsity patterns maintained and updated during training, specifically whenever optimizer.step() is executed.
Executing optimizer.apply(sparsity) transforms your optimizer into a sparse optimizer.
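The role of the step hook can be pictured with a toy optimizer. In this hypothetical sketch, SketchSGD and make_mask_hook are illustrative names, not cstorch APIs. A hook registered on the optimizer re-zeroes the masked positions after every step, so the dense update never revives pruned weights. It shows only the maintenance half; a dynamic algorithm would also change the masks inside such a hook.

```python
class SketchSGD:
    """Toy SGD over plain Python lists with post-step hooks (hypothetical)."""

    def __init__(self, params, lr=0.1):
        self.params = params  # name -> list of floats, updated in place
        self.lr = lr
        self._post_step_hooks = []

    def register_post_step_hook(self, fn):
        self._post_step_hooks.append(fn)

    def step(self, grads):
        # Dense update: every entry moves, including masked-out ones.
        for name, values in self.params.items():
            for i, g in enumerate(grads[name]):
                values[i] -= self.lr * g
        # Hooks run after the update, analogous to the ones a sparse optimizer installs.
        for hook in self._post_step_hooks:
            hook(self)


def make_mask_hook(masks):
    """Build a hook that re-zeroes masked-out entries after each step."""
    def hook(opt):
        for name, mask in masks.items():
            values = opt.params[name]
            for i, m in enumerate(mask):
                if m == 0:
                    values[i] = 0.0
    return hook


params = {"w": [1.0, 0.0, 2.0, 0.0]}
masks = {"w": [1, 0, 1, 0]}
opt = SketchSGD(params, lr=0.5)
opt.register_post_step_hook(make_mask_hook(masks))
opt.step({"w": [1.0, 1.0, 1.0, 1.0]})
print(params["w"])  # [0.5, 0.0, 1.5, 0.0]
```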
Important considerations
- The sparsity algorithm targets all optimizer states associated with a parameter, provided the state tensor has the same shape as the parameter. To exempt specific state tensors from sparsification, mark them as requiring dense storage by setting requires_dense = True on them:
If a state tensor does not have this attribute, the algorithm assumes that it is False.
- Sparsity algorithms typically include a hook that updates the sparsity pattern after each optimizer.step() call. This automatic update can be disabled if necessary:
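To make the state-selection rule from the first consideration concrete, here is a plain-Python sketch. FakeTensor and states_to_sparsify are hypothetical names, and the Adam-style state keys are only illustrative. A state entry is selected when it is tensor-like, has the same shape as its parameter, and is not marked requires_dense (a missing attribute counts as False).

```python
from dataclasses import dataclass


@dataclass
class FakeTensor:
    """Hypothetical stand-in that carries only a shape."""
    shape: tuple


def states_to_sparsify(param, state):
    """Select optimizer-state entries that share the parameter's shape and
    are not marked requires_dense (missing attribute counts as False)."""
    return sorted(
        key
        for key, value in state.items()
        if isinstance(value, FakeTensor)
        and value.shape == param.shape
        and not getattr(value, "requires_dense", False)
    )


weight = FakeTensor((4, 8))
adam_state = {
    "step": 10,  # scalar, never sparsified
    "exp_avg": FakeTensor((4, 8)),
    "exp_avg_sq": FakeTensor((4, 8)),
}
adam_state["exp_avg_sq"].requires_dense = True  # keep this state dense
print(states_to_sparsify(weight, adam_state))  # ['exp_avg']
```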
Sparsity Algorithms
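The cstorch API offers several out-of-the-box sparsity algorithms, including:
- Static sparsity, which keeps a fixed sparsity pattern throughout training.
- GMP (Gradual Magnitude Pruning), which progressively increases sparsity by pruning the smallest-magnitude weights.
- SET (Sparse Evolutionary Training), which periodically drops the smallest-magnitude weights and regrows the same number of connections at random.
- RigL (Rigged Lottery), which drops weights by magnitude and regrows the connections with the largest gradient magnitudes.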
Composing Sparsity Algorithms
You can apply distinct sparsity strategies to different parameter groups within your model. For instance, one group of weights might use static sparsity at a 30% level, while another group has its sparsity pattern adjusted dynamically by the SET algorithm. To do this, compose the individual sparsity algorithms into a single strategy with the Group class.
This grouped sparsity algorithm applies static sparsity to all model parameters matching the fc1.* glob pattern and the SET sparsity algorithm to parameters matching the fc2.* glob pattern.
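The glob-based routing behind such a grouping can be sketched with Python's standard fnmatch module. The assign_groups helper and the string labels are hypothetical; the first-match rule and leaving unmatched parameters unassigned are assumptions made for the sketch, not documented Group behavior.

```python
from fnmatch import fnmatch


def assign_groups(param_names, groups):
    """Map each parameter name to the first algorithm whose glob matches.

    `groups` is an ordered list of (glob, algorithm_label) pairs. Names that
    match no pattern are left out (assumed to stay dense in this sketch).
    """
    assignment = {}
    for name in param_names:
        for pattern, algo in groups:
            if fnmatch(name, pattern):
                assignment[name] = algo
                break
    return assignment


names = ["fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias", "head.weight"]
groups = [("fc1.*", "static(0.3)"), ("fc2.*", "SET")]
print(assign_groups(names, groups))
# {'fc1.weight': 'static(0.3)', 'fc1.bias': 'static(0.3)', 'fc2.weight': 'SET', 'fc2.bias': 'SET'}
```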
Writing Custom Sparsity Algorithms
All sparsity algorithms must inherit from the base class SparsityAlgorithm.
The only abstract method that must be overridden is update, which updates the sparsity patterns for all sparse parameters.
For algorithms that change the sparsity pattern dynamically, you can inherit from the convenience class DynamicSparsityAlgorithm, which takes care of many of the implementation details required to facilitate dynamic sparsity.
DynamicSparsityAlgorithm already implements update, but it exposes a new abstract method, update_mask, that must be overridden instead. update_mask takes in the existing sparsity pattern as a mask tensor and must return the new sparsity pattern as a mask tensor as well.
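The division of labor between update and update_mask can be sketched with Python's abc module. The classes below are hypothetical stand-ins, not the cstorch base classes: masks are plain lists, and passing the parameter name to update_mask is an assumption. The base class declares update as abstract; the dynamic subclass implements update by calling update_mask for each mask, so a concrete algorithm only supplies update_mask.

```python
from abc import ABC, abstractmethod


class SparsityAlgorithmSketch(ABC):
    """Hypothetical stand-in for the base class: holds one mask per parameter."""

    def __init__(self):
        self.masks = {}  # parameter name -> list of 0/1

    @abstractmethod
    def update(self):
        """Recompute the sparsity patterns for all sparse parameters."""


class DynamicSparsityAlgorithmSketch(SparsityAlgorithmSketch):
    """Implements update() by delegating each mask to update_mask()."""

    def update(self):
        for name, mask in self.masks.items():
            self.masks[name] = self.update_mask(name, mask)

    @abstractmethod
    def update_mask(self, name, mask):
        """Take the current mask and return the new mask."""


class ShiftRight(DynamicSparsityAlgorithmSketch):
    """Toy algorithm that rotates each mask one position (illustration only)."""

    def update_mask(self, name, mask):
        return mask[-1:] + mask[:-1]


algo = ShiftRight()
algo.masks["w"] = [1, 1, 0, 0]
algo.update()
print(algo.masks["w"])  # [0, 1, 1, 0]
```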
See GMP, SET, and RigL for examples of how to implement update_mask.
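For intuition about what an update_mask implementation does, here is a stdlib-only sketch of one common SET-style prune-and-regrow step on flat lists. It is not the cstorch SET implementation: the function name and signature are hypothetical, and regrowing only among previously inactive positions is one design choice among several.

```python
import random


def set_update_mask(weights, mask, drop_fraction, rng):
    """SET-style prune-and-regrow on flat lists (illustrative sketch).

    1. Drop the `drop_fraction` of active weights with the smallest magnitude.
    2. Regrow the same number of connections at random among previously
       inactive positions, so the overall sparsity is unchanged.
    """
    active = [i for i, m in enumerate(mask) if m]
    inactive = [i for i, m in enumerate(mask) if not m]
    n_swap = min(int(drop_fraction * len(active)), len(inactive))
    # The smallest-magnitude active weights are pruned first.
    dropped = sorted(active, key=lambda i: abs(weights[i]))[:n_swap]
    regrown = rng.sample(inactive, n_swap)
    new_mask = list(mask)
    for i in dropped:
        new_mask[i] = 0
    for i in regrown:
        new_mask[i] = 1
    return new_mask


weights = [0.9, -0.05, 0.4, 0.0, 0.0, 0.7, 0.0, 0.0]
mask = [1, 1, 1, 0, 0, 1, 0, 0]
new_mask = set_update_mask(weights, mask, drop_fraction=0.5, rng=random.Random(0))
print(sum(new_mask))  # 4
```

Because the number of dropped and regrown connections is equal, the sparsity level stays fixed while the pattern evolves.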
In addition, many building blocks are provided that can be used directly, inherited from, or composed to help build new DynamicSparsityAlgorithm subclasses. See Customizing Sparsity & Reference for more details.
Once you’ve written your custom sparsity algorithm, as long as it’s available in the global scope, you can use it directly or through a call to configure by setting algorithm to the name of your custom sparsity algorithm class. By extension, you can use it in ModelZoo the same way, by setting algorithm to the name of your custom sparsity algorithm class in your params YAML file (see sparsity_via_yaml for more details).
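The requirement that the class be in the global scope follows from name-based lookup. The sketch below shows that mechanism under stated assumptions: configure_sketch and MyCustomSparsity are hypothetical, and this is not how cstorch's configure is implemented, only an illustration of resolving a class from a name string such as one read from a YAML file.

```python
class MyCustomSparsity:
    """Hypothetical user-defined algorithm standing in for a SparsityAlgorithm subclass."""

    def __init__(self, sparsity):
        self.sparsity = sparsity


def configure_sketch(config, namespace=None):
    """Hypothetical resolver: look up config["algorithm"] by class name.

    The remaining keys are passed to the class constructor as keyword arguments.
    """
    namespace = globals() if namespace is None else namespace
    kwargs = dict(config)
    cls = namespace[kwargs.pop("algorithm")]
    return cls(**kwargs)


algo = configure_sketch({"algorithm": "MyCustomSparsity", "sparsity": 0.5})
print(type(algo).__name__, algo.sparsity)  # MyCustomSparsity 0.5
```

If the class were defined only inside a function, the lookup by name would fail with a KeyError, which is why the class needs to be visible globally.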
Implementation Notes
The Cerebras Wafer-Scale Cluster natively implements sparse computations in the Compressed Sparse Row (CSR) format. For user convenience, sparse models are represented as a combination of dense tensors and masks at the PyTorch level, with the compiler seamlessly converting between these representations.
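To show how a dense-tensor-plus-mask representation maps onto CSR, here is a plain-Python conversion. The to_csr helper is hypothetical and says nothing about how the Cerebras compiler performs the conversion internally; it only illustrates the standard CSR layout of values, column indices, and row pointers.

```python
def to_csr(dense, mask):
    """Convert a 2-D dense matrix plus a 0/1 mask into CSR arrays.

    Returns (values, col_indices, row_ptr): row i's kept entries are
    values[row_ptr[i]:row_ptr[i + 1]], located in the matching col_indices.
    """
    values, col_indices, row_ptr = [], [], [0]
    for row_vals, row_mask in zip(dense, mask):
        for j, (v, m) in enumerate(zip(row_vals, row_mask)):
            if m:
                values.append(v)
                col_indices.append(j)
        row_ptr.append(len(values))
    return values, col_indices, row_ptr


dense = [[1.0, 2.0, 3.0],
         [4.0, 5.0, 6.0]]
mask = [[1, 0, 1],
        [0, 1, 0]]
print(to_csr(dense, mask))  # ([1.0, 3.0, 5.0], [0, 2, 1], [0, 2, 3])
```

In this sketch the mask, not the stored value, decides which entries belong to the sparse structure, so a kept weight that happens to equal zero still occupies a CSR slot.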
While PyTorch provides tools for representing sparse tensors and utilities for pruning networks, these features might not fully align with the needs of the Cerebras Wafer-Scale Engine (WSE). Sparse tensors in PyTorch require specialized kernels and may not be entirely compatible with existing models and utilities; notably, a torch.nn.Parameter cannot directly hold a torch.sparse.Tensor without specific adjustments. The torch.nn.utils.prune utilities are convenient, but the asynchronous and precompiled nature of computation on the WSE requires a custom solution.
Similar to how torch.nn.utils.prune handles its mask tensors, when a sparsity algorithm is applied to the model, every sparsified parameter gets a mask tensor registered as a stateful buffer alongside it in the module that owns the parameter.
For example, take the following simple model:
Initially, the model’s state dictionary appears as follows:
After applying sparsity, the state dictionary also contains the mask tensors:
Here, the weight and weight_mask tensors together represent the sparsified weight.
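The effect on a state dictionary can be mimicked with plain dictionaries. In this hypothetical sketch, add_masks and effective_weight are illustrative helpers and lists stand in for tensors. Each sparsified entry such as fc1.weight gains a companion fc1.weight_mask entry, mirroring a weight_mask buffer registered on the owning module, and the effective sparse weight is the elementwise product of the two.

```python
def add_masks(state_dict, masks):
    """Return a copy of state_dict with a '<name>_mask' entry after each masked parameter."""
    out = {}
    for key, value in state_dict.items():
        out[key] = value
        if key in masks:
            out[key + "_mask"] = masks[key]
    return out


def effective_weight(state_dict, name):
    """Elementwise product of the dense tensor and its mask."""
    return [w * m for w, m in zip(state_dict[name], state_dict[name + "_mask"])]


state = {"fc1.weight": [0.3, 0.2, -0.9, 0.1], "fc1.bias": [0.0, 0.1]}
masks = {"fc1.weight": [1, 0, 1, 0]}
sparse_state = add_masks(state, masks)
print(list(sparse_state))  # ['fc1.weight', 'fc1.weight_mask', 'fc1.bias']
print(effective_weight(sparse_state, "fc1.weight"))  # [0.3, 0.0, -0.9, 0.0]
```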