Weight Streaming is the execution mode used by the Cerebras Wafer-Scale Cluster to deploy neural network models onto the Wafer-Scale Engine (WSE). In this mode, the model is loaded one layer at a time rather than all at once. This layer-by-layer approach supports models where a single layer’s weights fit in WSE memory but the full model does not.
The rest of this page illustrates Weight Streaming with a 3-layer fully connected MNIST (FC-MNIST) network.

Weight Streaming Execution during runtime

During runtime, Weight Streaming proceeds as shown in Fig. 3. At each step, one layer is loaded onto the Cerebras Wafer-Scale Engine (WSE). During forward propagation, the layer's weights stream from the MemoryX server to the WSE. During backpropagation, the weights are streamed from MemoryX to the WSE a second time. The WSE computes the weight gradients and streams them back to MemoryX, where the learning algorithm uses them to update the stored weights.
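The data movement described above can be traced with a small toy model. This is only a sketch under loud assumptions: each layer is reduced to a single scalar weight, the loss is a squared error, and the update rule is plain SGD. The function name `train_step` and the transfer labels are hypothetical and are not part of any Cerebras API.

```python
# Toy model of one Weight Streaming step. Each "layer" is one scalar weight
# so the transfers are easy to follow. All names here are illustrative.

def train_step(memoryx_weights, x, target, lr=0.1):
    """Run one forward/backward step, loading one layer's weight at a time."""
    transfers = []        # log of (direction, layer index) events
    activations = [x]     # activations stay in the "WSE" memory

    # Forward: stream each layer's weight in, compute, keep the activation.
    for i, w in enumerate(memoryx_weights):
        transfers.append(("MemoryX->WSE weights", i))
        activations.append(w * activations[-1])

    # Assumed loss L = 0.5 * (y - target)^2, so the loss delta is y - target.
    delta = activations[-1] - target

    # Backward: stream the weight in again, compute its gradient, stream it out.
    for i in reversed(range(len(memoryx_weights))):
        w = memoryx_weights[i]
        transfers.append(("MemoryX->WSE weights", i))
        grad = delta * activations[i]   # dL/dw_i
        delta = delta * w               # delta passed to the previous layer
        transfers.append(("WSE->MemoryX gradients", i))
        memoryx_weights[i] = w - lr * grad  # update applied on the MemoryX side
    return transfers


weights = [2.0, 3.0]
log = train_step(weights, x=1.0, target=5.0)
for direction, layer in log:
    print(f"layer {layer}: {direction}")
```

Note that each weight crosses from MemoryX to the WSE twice per step (once per pass), while each gradient crosses back once.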

Compilation

The Cerebras compiler extracts the operation graph from the code and matches each operation to a compatible kernel in the Cerebras Software Platform. Each matched kernel corresponds to a layer in the network's dataflow graph; Fig. 4 shows a mathematical representation of these layers. To inspect the results of this phase, use the --validate_only flag. The compiler then maps one kernel (layer) at a time onto the entire WSE, first for the forward pass and then for the backward pass. When multiple CS-X systems are requested, the same mapping is reused on every system. To run only the compilation step without executing training, use the --compile_only flag.
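The ordering implied by "first for the forward pass and then for the backward pass" can be written out as a short sketch. It only enumerates the order in which layers occupy the WSE during one training step; it does not model the compiler itself. The function `layer_schedule` and the layer names `fc1`, `fc2`, `fc3` are hypothetical.

```python
def layer_schedule(layers):
    """Return the order in which layers occupy the WSE in one training step."""
    forward = [("forward", name) for name in layers]
    # The backward pass visits the same layers in reverse order.
    backward = [("backward", name) for name in reversed(layers)]
    return forward + backward


schedule = layer_schedule(["fc1", "fc2", "fc3"])
for phase, layer in schedule:
    print(f"{phase:8s} {layer}")
```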

Training

During the forward propagation phase, as illustrated in Fig. 5:
  • Layer 1 is loaded onto the WSE first. Input pre-processing servers process the training data and stream it to the WSE, while MemoryX streams the weights of layer 1 to the WSE. If the training job requests multiple CS-X systems, SwarmX broadcasts the weights from MemoryX to all the WSEs. The batch of training samples is divided into equally sized subsets, or shards, and each shard is sent to its own CS-X. This technique is known as data parallelism.
  • All WSEs perform the forward computation for layer 1 in parallel.
  • The calculated activations for layer 1 are retained in the WSE’s memory.
  • Subsequently, MemoryX broadcasts the weights of layer 2 to the WSEs.
  • Each WSE performs the forward computation for layer 2, using its stored activations from layer 1.
  • This same process is repeated for layer 3.
  • In this way, the forward computation for each layer uses the activations computed by the preceding layer, and the activations of the current layer are stored in the WSE's memory for use by the next loaded layer.
  • At the loss layer, the labels from the training data are used to compute the network loss delta, which is the gradient of the scalar loss with respect to the activations of the output layer (layer 3). The backward pass uses this loss delta to compute the layer-by-layer deltas and weight gradients.
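The forward steps above can be sketched in a few lines of plain Python. This is a minimal illustration, not the Cerebras implementation: each layer is assumed to be a scalar `relu(weight * x + bias)`, the loss is assumed to be a squared error, and every function name is hypothetical.

```python
def shard_batch(batch, num_systems):
    """Split a batch into equally sized shards, one per CS-X system."""
    if len(batch) % num_systems != 0:
        raise ValueError("batch size must be divisible by the number of systems")
    size = len(batch) // num_systems
    return [batch[i * size:(i + 1) * size] for i in range(num_systems)]


def forward_layer(weight, bias, inputs):
    """Toy layer: relu(weight * x + bias) applied to each sample."""
    return [max(0.0, weight * x + bias) for x in inputs]


def forward_pass(layers, batch, num_systems):
    # activations[s] holds the activations kept in the memory of system s.
    activations = shard_batch(batch, num_systems)
    for weight, bias in layers:  # one layer is streamed in at a time
        # Broadcast: every system receives the same (weight, bias).
        activations = [forward_layer(weight, bias, acts) for acts in activations]
    return activations


def loss_delta(outputs, labels):
    """dL/dy per sample for the assumed loss L = 0.5 * (y - label)^2."""
    return [y - t for y, t in zip(outputs, labels)]


layers = [(2.0, 0.0), (1.0, 1.0), (0.5, 0.0)]  # three toy layers
batch = [1.0, 2.0, 3.0, 4.0]
labels = [1.0, 2.0, 3.0, 4.0]

outputs = forward_pass(layers, batch, num_systems=2)
label_shards = shard_batch(labels, 2)
deltas = [loss_delta(o, t) for o, t in zip(outputs, label_shards)]
print(outputs)
print(deltas)
```

Each system only ever sees its own shard of the batch, but all systems see identical weights, which is what makes the later gradient sum equivalent to a full-batch gradient.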
In the backward propagation phase, as illustrated in Fig. 6:
  • The weights of layer 3 are broadcast from MemoryX to the WSEs, where the gradient and delta computations for layer 3 are performed. In practice, the weights of this output layer can be retained on the WSE to save time.
  • The gradients for layer 3 are then streamed out of the WSEs to SwarmX. SwarmX aggregates (sums) the gradients from the WSEs and passes the sum to MemoryX. MemoryX uses this aggregate gradient in the learning algorithm to update its stored copy of the layer weights.
  • Next, the weights of layer 2 are streamed from MemoryX to the WSE, and the WSE performs the same gradient and delta computations for layer 2. The layer 2 gradients are then streamed out of the WSE to MemoryX, where the weight update takes place. If the training job requests multiple CS-X systems, SwarmX broadcasts the weights from MemoryX to all the WSEs and reduces the gradients from the WSEs before they reach MemoryX.
  • This backward pass sequence continues until the gradients for layer one are streamed out to MemoryX, where weight updates occur. Following this, the forward pass for the next training batch commences with the updated weights.
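The aggregate-then-update step in the list above can be illustrated for a single layer. The sketch below assumes a one-weight linear layer, a squared-error loss summed over samples, and plain SGD as the learning algorithm; the source does not specify any of these. The names `shard_gradient`, `swarmx_reduce`, and `memoryx_update` are hypothetical.

```python
def shard_gradient(weight, inputs, labels):
    """Gradient w.r.t. weight of sum(0.5 * (weight * x - label)^2) over a shard."""
    return sum((weight * x - t) * x for x, t in zip(inputs, labels))


def swarmx_reduce(gradients):
    """Aggregate by summing the per-system gradients."""
    return sum(gradients)


def memoryx_update(weight, aggregate_gradient, lr):
    """Plain SGD step applied to the stored copy of the weight (an assumption)."""
    return weight - lr * aggregate_gradient


weight = 1.0
# Two systems, each holding one shard of (inputs, labels).
shards = [([1.0, 2.0], [2.0, 4.0]), ([3.0, 4.0], [6.0, 8.0])]

per_system = [shard_gradient(weight, xs, ts) for xs, ts in shards]
aggregate = swarmx_reduce(per_system)
new_weight = memoryx_update(weight, aggregate, lr=0.01)

# The summed shard gradients equal the gradient over the whole batch.
full_batch = shard_gradient(weight, [1.0, 2.0, 3.0, 4.0], [2.0, 4.0, 6.0, 8.0])
print(per_system, aggregate, full_batch, new_weight)
```

Because the sum over shards matches the full-batch gradient, MemoryX can apply a single update per layer regardless of how many systems took part.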
Meanwhile, within the user node:
  • As the loss values are calculated on the WSEs, SwarmX consolidates them and then transmits the results to the user node.
  • At specified intervals, all the weights can be retrieved from MemoryX and transferred to the user node for checkpointing.
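A rough sketch of these two user-node activities follows. The source does not say how SwarmX consolidates the loss values, so the averaging below is an assumption that only holds for equally sized shards. The names `consolidate_loss`, `should_checkpoint`, and `checkpoint_steps` are hypothetical.

```python
def consolidate_loss(per_system_losses):
    """Average per-system mean losses (assumes equally sized shards)."""
    return sum(per_system_losses) / len(per_system_losses)


def should_checkpoint(step, checkpoint_steps):
    """True when `step` falls on the checkpoint interval."""
    return checkpoint_steps > 0 and step % checkpoint_steps == 0


loss = consolidate_loss([0.5, 0.7])
# Steps (out of 10) at which weights would be pulled from MemoryX.
checkpoints = [step for step in range(1, 11) if should_checkpoint(step, 5)]
print(loss, checkpoints)
```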