Train A Model With Weight Sparsity
This page will cover how to configure the Trainer with weight sparsity. By the end, you should be familiar with how to use sparsity in tandem with the Trainer for any model.
Prerequisites
Make sure you have read Trainer Overview and Trainer Configuration Overview, which provide a basic overview of how to run Model Zoo models. This document uses the tools and configurations outlined in those pages.
Background
In 2018, state-of-the-art neural networks such as BERT had a few hundred million parameters. Two years later, the world was introduced to GPT-3. With 175 billion parameters and a compute budget of 3.14 × 10^23 FLOPs (floating point operations), it is estimated to have required 10,000 NVIDIA V100 GPUs for 15 days, accounting for 552 tons of CO2e emissions and 1,287 MWh of energy [Patterson et al.].
Evidently, training large models is costly. With parameter counts and datasets getting larger and larger every year, new approaches are needed to reduce the time, energy, and carbon footprint required to train. Weight sparsity, coupled with hardware that accelerates it, is a promising way to train models using significantly less compute and memory.
Weight sparse training methods set subsets of weights to zero. The resulting sparse model requires far fewer FLOPs to train and fewer parameters to store, as multiplies with zeros get skipped on both forward and backward passes through the network. Only systems that can accelerate sparsity, such as Cerebras CS-X and CS-3, can take advantage of the lower resource requirement and use the reduction in FLOPs to significantly accelerate training. Finding and training sparse models to match the accuracy of their original “dense” (i.e., non-sparse) configurations is an active and open area of research!
Configure Sparsity
Let’s expand on the minimal example shown in sparsity.
For example, with a config that sets the sparsity level to 0.3 (30%) and init_method to "random", 30% of the elements in each Parameter that passes the default parameter filter are pruned once at model initialization and kept that way throughout training. Non-Parameter tensors are not pruned.
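To make "pruned once at initialization" concrete, here is a minimal pure-Python sketch of how a keep-mask could be built and applied. It is an illustration only, not the Cerebras implementation: the helper names (`num_pruned`, `random_mask`, `topk_mask`, `from_zeros_mask`, `apply_mask`) are hypothetical, and rounding the pruned count down is an assumption. The `topk` and `from_zeros` values of init_method, which this page also documents, are included for comparison.

```python
import math
import random


def num_pruned(numel: int, sparsity: float) -> int:
    # Assumption: the pruned count is rounded down to a whole number of elements.
    return math.floor(numel * sparsity)


def random_mask(numel, sparsity, seed=0):
    """init_method "random": prune randomly chosen positions (True = keep)."""
    rng = random.Random(seed)
    pruned = set(rng.sample(range(numel), num_pruned(numel, sparsity)))
    return [i not in pruned for i in range(numel)]


def topk_mask(weights, sparsity):
    """init_method "topk": prune the positions with the lowest magnitude."""
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    pruned = set(order[: num_pruned(len(weights), sparsity)])
    return [i not in pruned for i in range(len(weights))]


def from_zeros_mask(weights):
    """init_method "from_zeros": keep exactly the positions that are non-zero."""
    return [w != 0.0 for w in weights]


def apply_mask(weights, mask):
    """Zero out pruned positions; static sparsity reuses one mask for all steps."""
    return [w if keep else 0.0 for w, keep in zip(weights, mask)]


weights = [0.5, -1.2, 0.3, 2.0, -0.7, 0.9, -0.1, 1.4, 0.05, -2.2]
static_mask = random_mask(len(weights), sparsity=0.3)  # computed once
sparse_weights = apply_mask(weights, static_mask)
print("pruned positions:", [i for i, keep in enumerate(static_mask) if not keep])
print("actual sparsity:", static_mask.count(False) / len(static_mask))
```

Because only whole elements can be pruned, the achieved level depends on the Parameter size: under the round-down assumption, `num_pruned(5, 0.5)` is 2, which gives an actual sparsity of 0.4 for a five-element Parameter.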
Sparsity is parameterized primarily by the following keys:
- algorithm: The sparsity algorithm to apply. You can also define a custom class that inherits from SparsityAlgorithm.
  - YAML: As long as the class is in the global scope (for example, because it is imported in your run.py), it can be used directly in a YAML config.
  - Python: The CustomSparsity class can be passed directly to the Trainer.
  See Writing a Custom Sparsity Algorithm for more details on how to write a custom sparsity algorithm.
- sparsity: The desired sparsity level between 0 and 1. 0.0 means the Parameter is kept fully dense. 1.0 means the Parameter is effectively entirely zeros. Dynamic sparsity algorithms also accept more complex configuration described below in Dynamic Hyperparameters.
  The actual sparsity level may not match the target sparsity level in practice. The target sparsity level only represents a target distribution. The true sparsity level also depends on the number of elements in the Parameter that is being sparsified, because only a whole number of elements can be pruned.
  For example, if you were to sparsify a Parameter with shape (5,) targeting a sparsity level of 0.5, the actual sparsity level will only ever be 0.4 (2 of 5 elements pruned). The smaller the Parameter, the more extreme this discrepancy becomes. If the Parameter is a scalar tensor, then the actual sparsity level will always be either 0.0 or 1.0.
- init_method (optional): Method to compute the initial sparsity distribution.
  - random: (default) Sparsity is randomly distributed within each weight.
  - topk: Sparsity is distributed according to the lowest magnitude weights.
  - from_zeros: Sparsity pattern is determined by weight values that are already zero.
- param_filter (optional): Controls which Parameters are sparsified. The list of Parameter names can be found using model.named_parameters(). When this is omitted, sparsity is automatically applied to any multidimensional Parameter except those with embedding, norm, or lm_head in their name; single-dimensional weights such as biases are ignored (see default_sparse_param_filter). While this provides a good default heuristic for transformer-based models [1], a glob expression (or a list of glob expressions) can also be provided to apply sparsity only to Parameters that match, e.g.

  ```yaml
  trainer:
    init:
      sparsity:
        …
        param_filter:
          - "*dense_layer.weight"
          - "*linear_layer.weight"
  ```

  To match all weights, set param_filter: "*".

  Per-layer sparsity options can be configured by passing in a list of configuration dictionaries. See below in advanced param_filters.
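The selection rules for param_filter can be sketched in a few lines of Python. This is a rough approximation of the behaviour described above, not the library's code: `default_filter` and `glob_filter` are hypothetical helpers, the Parameter names and shapes are made up for illustration, and matching globs with the standard-library `fnmatch` module is an assumption.

```python
from fnmatch import fnmatchcase

# Hypothetical (name, shape) pairs, as model.named_parameters() might yield them.
PARAMS = [
    ("embedding_layer.word_embeddings.weight", (50257, 768)),
    ("decoder.0.attn.proj_q_dense_layer.weight", (768, 768)),
    ("decoder.0.attn.proj_q_dense_layer.bias", (768,)),
    ("decoder.0.norm1.weight", (768,)),
    ("decoder.0.ffn.linear_layer.weight", (3072, 768)),
    ("lm_head.weight", (50257, 768)),
]


def default_filter(name, shape):
    """Approximate default: multidimensional, and not embedding/norm/lm_head."""
    if len(shape) <= 1:
        return False  # biases and other single-dimensional weights are skipped
    return not any(token in name for token in ("embedding", "norm", "lm_head"))


def glob_filter(name, patterns):
    """Select a Parameter when its name matches any of the glob patterns."""
    return any(fnmatchcase(name, pattern) for pattern in patterns)


by_default = [name for name, shape in PARAMS if default_filter(name, shape)]
by_glob = [
    name
    for name, _ in PARAMS
    if glob_filter(name, ["*dense_layer.weight", "*linear_layer.weight"])
]
everything = [name for name, _ in PARAMS if glob_filter(name, ["*"])]

print(by_default)
print(by_glob)
print(len(everything), "of", len(PARAMS), "matched by '*'")
```

For this made-up model the default heuristic and the two glob patterns happen to select the same two weights, while `"*"` selects every Parameter, including biases and embeddings.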
Dynamic Sparsity Update Schedule
Dynamic sparsity (e.g. GMP, SET, or RigL) needs an additional update schedule indicating when to update the sparsity pattern. There are 2 basic methods built-in with 3 different options:
Regular Interval
When sparsity should be updated at a regular interval, a single frequency can be given. For example, with SET at a sparsity level of 0.9 and a frequency of 100, sparsity will be initialized at 90% and steps 0,…,99 will be performed with a fixed sparsity pattern. Every 100 steps, the sparsity pattern will be updated according to the SET algorithm.
To control beginning and ending steps, use a dictionary. For example, with a start of 77, a frequency of 100, and a stop of 377, sparsity will be initialized at 0% and steps 0,…,76 will be performed without sparsity. Starting from step 77 and every 100 steps until step 377 (that is, at steps 77, 177, and 277), the sparsity pattern will be updated according to the SET algorithm. After step 377, the sparsity pattern will continue to be applied, but it will no longer be updated (stop is exclusive).
Irregular Interval
When sparsity should be updated at arbitrary steps, specify them in a list.
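The two kinds of schedule can be pictured as simple predicates on the step number. The sketch below is a plain-Python illustration under stated assumptions, not the Trainer's scheduling code: the function names and the `freq` and `start` argument names are hypothetical, while the exclusive `stop` follows the description above.

```python
def is_regular_update(step, freq, start=0, stop=None):
    """True when `step` is an update step of a regular-interval schedule."""
    if step < start:
        return False  # before the schedule begins
    if stop is not None and step >= stop:
        return False  # stop is exclusive
    return (step - start) % freq == 0


def is_listed_update(step, steps):
    """True when `step` is one of the explicitly listed update steps."""
    return step in set(steps)


# Regular interval with start=77, freq=100, stop=377.
regular = [s for s in range(500) if is_regular_update(s, freq=100, start=77, stop=377)]
print(regular)

# Irregular interval: arbitrary steps chosen for illustration.
irregular = [s for s in range(500) if is_listed_update(s, [0, 5, 20, 50])]
print(irregular)
```

With only a frequency given, this sketch also reports step 0 as an update step, which coincides with the initial sparsity pattern being set.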
Dynamic Hyperparameters
Dynamic sparsity algorithms (e.g. GMP, SET, or RigL) can configure the sparsity field (and the drop_fraction field for SET and RigL) using a "step-aware hyperparameter", akin to a learning rate schedule, in addition to simple constants. These more complex configurations usually require additional options and so are specified as dictionaries.
Note
The base DynamicSparsityAlgorithm that invokes such a dynamic hyperparameter for sparsity ensures sparsity levels stay legal by using torch.clamp(sparsity, min=0.0, max=1.0).
Linear
Exponential
This is especially useful for GMP, where the sparsity level monotonically increases throughout training because a fraction of the remaining elements in the Parameter are pruned at each update step, asymptotically approaching an empty network.
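The "fraction of the remaining elements" behaviour can be written as a small formula: if a fraction f of the remaining elements is pruned at every update, the density shrinks geometrically, so after k updates the sparsity is 1 - (1 - s0)(1 - f)^k. The sketch below works through that arithmetic. The function and argument names are hypothetical, and this is not necessarily how the exponential schedule is parameterized in the config.

```python
def clamp(value, low=0.0, high=1.0):
    """Keep a sparsity level inside the legal range [0, 1]."""
    return max(low, min(high, value))


def gmp_style_sparsity(update_index, init_sparsity, prune_fraction):
    """Sparsity after `update_index` updates when `prune_fraction` of the
    remaining (non-zero) elements is pruned at each update."""
    remaining = (1.0 - init_sparsity) * (1.0 - prune_fraction) ** update_index
    return clamp(1.0 - remaining)


# Starting dense and pruning half of what remains at each update.
levels = [gmp_style_sparsity(k, init_sparsity=0.0, prune_fraction=0.5) for k in range(5)]
print(levels)
```

The levels increase at every update and approach 1.0 without reaching it, which is the asymptotic behaviour the paragraph above describes.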
Cosine
This is especially useful for RigL, which usually uses a "cosine decay" on its drop_fraction. The minimum option defaults to 0.0, and half_period controls the step at which the value reaches its minimum.
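As an illustration of a cosine decay with these two options, the sketch below uses the textbook half-cosine formula. The formula itself and the `init` argument name are assumptions; only `minimum` and `half_period` come from the description above, and the clamp mirrors the torch.clamp behaviour mentioned in the note on this page.

```python
import math


def clamp(value, low=0.0, high=1.0):
    """Keep the value inside [0, 1], like torch.clamp(x, min=0.0, max=1.0)."""
    return max(low, min(high, value))


def cosine_decay(step, init, half_period, minimum=0.0):
    """Decay from `init` at step 0 to `minimum` at step `half_period`.

    Assumption: a plain half-cosine. Past `half_period` this formula rises
    again; how later steps are handled is not specified on this page.
    """
    progress = step / half_period
    value = minimum + (init - minimum) * 0.5 * (1.0 + math.cos(math.pi * progress))
    return clamp(value)


for step in (0, 250, 500, 750, 1000):
    print(step, round(cosine_decay(step, init=0.3, half_period=1000), 4))
```

A schedule like this starts with larger pattern changes early in training and tapers them off as the step count approaches half_period.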
More Config examples
The most basic configuration, applying random 30% sparsity to all Parameters:
Apply uniform (static) sparsity to a selected set of weights, with a sparsity pattern guided by the weight magnitudes:
Basic dynamic sparsity using the SET algorithm. Update the sparsity pattern every 1000 iterations.
Configuring Multiple Sparsity Algorithms
Different groups of Parameters can be sparsified using different sparsity algorithms.
For example, if one set of weights should be statically sparsified to, say, 0.3, but another set of weights should be dynamically sparsified using the SET algorithm, this can be done by providing a list of sparsity algorithms.
Advanced param_filters
When each Parameter (or group of Parameters) needs different configuration, param_filters can be specified as a dictionary, mapping "patterns" to the config dictionaries to overlay on the default sparsity config options.
For example, when using RigL on transformer networks (RigL uses gradient information to guide which values in a Parameter to prune), sparsity can be cyclically redistributed between the heads of attention projection weights if samples in a batch activate one head disproportionately more than another. This ultimately decreases the effectiveness of dynamic sparsity and can even hurt model performance.
To ensure sparsity is fairly distributed between the different attention heads of the multi-head attention projections, you can specify balance_out_groups when the output logits are logically N independent/stacked groups (i.e. input projection weights before multi-head attention QKV), or balance_in_groups for the reverse (i.e. output projection weights). Because this conceptually applies only to attention projection weights, these options should be applied to different weights using param_filter. For a model with 12 attention heads, for example, N is 12.
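The toy sketch below illustrates what balancing across output groups means, using simple magnitude pruning on a tiny matrix with two "heads". It is only a conceptual picture under stated assumptions: RigL also uses gradient information, the helper names are hypothetical, and this is not how balance_out_groups is implemented.

```python
import math


def magnitude_mask(values, sparsity):
    """Keep-mask that prunes the lowest-magnitude entries of a flat list."""
    k = math.floor(len(values) * sparsity)
    order = sorted(range(len(values)), key=lambda i: abs(values[i]))
    pruned = set(order[:k])
    return [i not in pruned for i in range(len(values))]


def pruned_per_group(flat_mask, num_groups):
    """Count pruned entries in each contiguous block of output rows."""
    size = len(flat_mask) // num_groups
    return [flat_mask[g * size:(g + 1) * size].count(False) for g in range(num_groups)]


def balanced_out_mask(weight, num_groups, sparsity):
    """Prune each group of output rows separately so every group ends up
    with the same sparsity level."""
    rows_per_group = len(weight) // num_groups
    mask = []
    for g in range(num_groups):
        block = weight[g * rows_per_group:(g + 1) * rows_per_group]
        mask.extend(magnitude_mask([v for row in block for v in row], sparsity))
    return mask


# Rows 0-1 belong to head 0 (large values); rows 2-3 to head 1 (small values).
weight = [
    [1.0, 2.0, 3.0, 4.0],
    [5.0, 6.0, 7.0, 8.0],
    [0.1, 0.2, 0.3, 0.4],
    [0.5, 0.6, 0.7, 0.8],
]
flat = [v for row in weight for v in row]

unbalanced = pruned_per_group(magnitude_mask(flat, 0.5), num_groups=2)
balanced = pruned_per_group(balanced_out_mask(weight, 2, 0.5), num_groups=2)
print("pruned per head without balancing:", unbalanced)
print("pruned per head with balancing:   ", balanced)
```

Without balancing, all pruning lands on the head with the smaller values and that head is emptied entirely; with balancing, each head loses the same number of entries.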
Running a Sparse Model
No change is needed to the run command (see guide: Launch your job); just ensure the .yaml file has sparsity enabled. To validate your sparsity config before launching training, run with --validate_only. You can also log which weights are being sparsified by passing --logging VERBOSE to your run command.
When using dynamic sparsity, you can see realtime summaries by using the LogSparsity callback.
Sparsity via API
Please see Sparsifying models for more details on how to configure sparsity using the Cerebras PyTorch API.
Related research
Sparsity is a powerful tool that can improve performance, reduce the model size, and help generalizability while achieving the same accuracy as densely trained models. Read some of our research work on training with sparsity here:
- Sparse pre-training and dense fine-tuning (SPDF): blog, arxiv, PMLR. This work shows how we can pre-train a 1.3B parameter GPT-3 style model with up to 75% unstructured sparsity and 60% fewer training FLOPs on Cerebras CS-X, without significantly losing accuracy on downstream tasks.
- Sparse Iso-FLOP transformations for Maximizing Training Efficiency (Sparse-IFT): blog, arxiv. This work shows how pre-training a GPT-3 style small model with sparsity leads to a 0.4 perplexity improvement on the WikiText103 language modeling task. See the table below and paper for more details, including results on computer vision tasks.
- Variable Sparse pre-training and dense fine-tuning (Variable SPDF): blog. This work extends SPDF, demonstrating the scaling up to a 6.7B GPT-3 style model with a 64% FLOPs reduction while maintaining downstream model performance. This is accomplished by performing the majority of pre-training at static sparsity, then finishing pre-training on a dense version of the model before fine-tuning on the dense version as well.
Footnotes
[1] Sparsity on CS-X has not yet been thoroughly validated with convolutional networks.