On this page, you will learn how to set up Multi-Phase training using the Trainer class. Multi-Phase training allows you to combine multiple training phases with different batch sizes or max sequence lengths in a single config file or Python script.
Prerequisites
Please make sure you have read the following tutorials beforehand:
The rest of this page assumes that you already have at least a cursory understanding of what the Cerebras Model Zoo Trainer is and how to use the Python API.
Multi-Phase Training
In Multi-Phase training, you define several distinct training phases. For example, the training pipeline for the Llama-3 model might use different batch sizes or max sequence lengths in different phases. Each of these phases is defined by its own instance of the Trainer class.
Let’s consider an example. In the Pretraining with Upstream Validation tutorial, you learned how to construct the Trainer for the Llama-3 model. Now, let’s add a new training phase with a different batch size and a different max sequence length.
To define each phase, you need to construct a separate Trainer instance. For example:
trainer:
  - trainer:
      ...
  - trainer:
      ...
The number of Trainer instances is not limited, and each Trainer can have different parameters, so you can construct arbitrary training/validation pipelines that include different models, dataloaders, and so on.
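To make the nesting concrete, here is a minimal Python sketch that mirrors the YAML layout above as plain dictionaries and walks the phases in the order they are listed. The run_phases helper, the run_phase callback, and the "name" placeholders are hypothetical stand-ins for constructing and running a Trainer; this is not how the Model Zoo itself parses the config.

```python
from typing import Any, Callable, Dict

# Mirrors the YAML layout: a top-level "trainer" key holding a list,
# where each entry is a mapping with a single "trainer" key (one phase).
config: Dict[str, Any] = {
    "trainer": [
        {"trainer": {"name": "phase_1"}},  # hypothetical placeholder settings
        {"trainer": {"name": "phase_2"}},
    ]
}


def run_phases(cfg: Dict[str, Any],
               run_phase: Callable[[int, Dict[str, Any]], None]) -> int:
    """Call run_phase once per phase, in list order; return the number of phases."""
    phases = cfg["trainer"]
    if not isinstance(phases, list):
        raise TypeError("expected 'trainer' to hold a list of phases")
    for index, entry in enumerate(phases):
        # Each list entry must be exactly {"trainer": {...}}.
        if set(entry) != {"trainer"}:
            raise ValueError(f"phase {index} must contain exactly one 'trainer' key")
        run_phase(index, entry["trainer"])
    return len(phases)


if __name__ == "__main__":
    seen = []
    count = run_phases(config, lambda i, params: seen.append((i, params["name"])))
    print(count, seen)  # 2 [(0, 'phase_1'), (1, 'phase_2')]
```

Because every phase is just another list entry, adding a third phase only means appending one more {"trainer": {...}} mapping.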
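For each phase, we define a different batch size and a different max sequence length: the first phase trains on the msl8192 dataset with a batch size of 80, and the second phase trains on the msl512 dataset with a batch size of 40.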
trainer:
  - trainer:
      init: &init
        backend:
          backend_type: CSX
          cluster_config:
            num_csx: 16
        seed: 2024
        model:
          # Embedding
          vocab_size: 128256
          hidden_size: 4096
          position_embedding_type: "rotary"
          pos_scaling_factor: 1.0
          rope_theta: 500000.0
          rotary_dim: 128
          share_embedding_weights: false
          max_position_embeddings: 8192
          embedding_dropout_rate: 0.0
          embedding_layer_norm: false
          # Decoder
          num_hidden_layers: 32
          dropout_rate: 0.0
          layer_norm_epsilon: 1.0e-5
          norm_type: "rmsnorm"
          # Decoder - Attention
          num_heads: 32
          attention_type: "scaled_dot_product"
          attention_module: "multiquery_attention"
          attention_dropout_rate: 0.0
          use_projection_bias_in_attention: false
          use_ffn_bias_in_attention: false
          extra_attention_params:
            num_kv_groups: 8
          # Decoder - ffn
          filter_size: 14336
          nonlinearity: "swiglu"
          use_ffn_bias: false
          # Task-specific
          use_bias_in_output: false
          loss_scaling: "num_tokens"
          loss_weight: 1.0
          # Initializer
          initializer_range: 0.02
          # Cerebras parameters
          mixed_precision: True
          fp16_type: "cbfloat16"
        optimizer:
          AdamW:
            betas: [0.9, 0.95]
            correct_bias: True
            weight_decay: 0.1
        schedulers:
          - CosineDecayLR:
              initial_learning_rate: 3.0e-5
              end_learning_rate: 3.0e-6
              total_iters: 528
        precision:
          fp16_type: cbfloat16
          loss_scaling_factor: dynamic
          max_gradient_norm: 1.0
        loop:
          num_steps: 10000
          eval_frequency: 1000
          eval_steps: 1000
        checkpoint:
          steps: 1000
        callbacks:
          - ComputeNorm: {}
          - CheckLoss: {}
          - ModelEvalMetrics: {}
        loggers:
          - ProgressLogger: {}
          - TensorBoardLogger: {}
      fit:
        train_dataloader:
          data_processor: GptHDF5MapDataProcessor
          data_dir: "/data/llama_v3_dataset_vocab128256_msl8192/train"
          batch_size: 80
          shuffle: False
          shuffle_seed: 1337
          num_workers: 8
          prefetch_factor: 10
          persistent_workers: True # Important to avoid seeding at each epoch
        val_dataloader:
          - data_processor: GptHDF5MapDataProcessor
            data_dir: "/data/llama_v3_dataset_vocab128256_msl8192/val"
            batch_size: 80
            shuffle: False
            shuffle_seed: 1337
            num_workers: 8
            prefetch_factor: 10
            persistent_workers: True # Important to avoid seeding at each epoch
  - trainer:
      init:
        <<: *init
      fit:
        train_dataloader:
          data_processor: GptHDF5MapDataProcessor
          data_dir: "/data/llama_v3_dataset_vocab128256_msl512/train"
          batch_size: 40
          shuffle: False
          shuffle_seed: 1337
          num_workers: 8
          prefetch_factor: 10
          persistent_workers: True # Important to avoid seeding at each epoch
        val_dataloader:
          - data_processor: GptHDF5MapDataProcessor
            data_dir: "/data/llama_v3_dataset_vocab128256_msl512/val"
            batch_size: 40
            shuffle: False
            shuffle_seed: 1337
            num_workers: 8
            prefetch_factor: 10
            persistent_workers: True # Important to avoid seeding at each epoch
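The second phase does not repeat the model, optimizer, or loop settings: the &init anchor on the first phase's init section and the <<: *init merge key in the second phase copy those settings over, while fit is written out separately for each phase. The Python sketch below imitates that merge with plain dictionaries and computes the tokens processed per step in each phase (batch size times max sequence length). The merge_key helper is hypothetical, the dictionary is a trimmed-down copy of the YAML above, and the shallow-merge behavior follows the YAML merge-key convention rather than anything Cerebras-specific.

```python
import copy
from typing import Any, Dict

# Trimmed-down copy of the first phase's init section (the &init anchor).
init_phase1: Dict[str, Any] = {
    "backend": {"backend_type": "CSX", "cluster_config": {"num_csx": 16}},
    "seed": 2024,
    "loop": {"num_steps": 10000, "eval_frequency": 1000, "eval_steps": 1000},
}


def merge_key(anchor: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]:
    """Imitate '<<: *anchor': a shallow merge in which explicitly written keys win."""
    merged = copy.deepcopy(anchor)  # copy so the two phases do not share nested dicts
    merged.update(overrides)        # replaces whole top-level values, never nested keys
    return merged


# The second phase's init is just "<<: *init", i.e. no overrides.
init_phase2 = merge_key(init_phase1, {})

# Overriding a nested mapping replaces it entirely (shallow merge):
# eval_frequency and eval_steps are NOT carried over here.
short_loop = merge_key(init_phase1, {"loop": {"num_steps": 500}})

# (batch_size, max sequence length) per phase; the sequence lengths are read
# off the data_dir names (msl8192 and msl512) in the YAML above.
phase_shapes = [(80, 8192), (40, 512)]
tokens_per_step = [batch * msl for batch, msl in phase_shapes]

if __name__ == "__main__":
    print(init_phase2 == init_phase1)  # True
    print(short_loop["loop"])          # {'num_steps': 500}
    print(tokens_per_step)             # [655360, 20480]
```

One consequence of the shallow merge: to change a single nested key (for example only loop.num_steps) in a later phase, you would restate the whole nested mapping in that phase's init.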
Note that when you use YAML, you have to construct a separate Trainer instance for each phase, which adds overhead to your run because of the time each phase spends on compilation and weight transfer. If you are using the Python API, you can instead construct a single Trainer object and call its fit method several times with different DataLoader objects.
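The Python API pattern can be sketched with a hypothetical StubTrainer class standing in for the real Trainer: construction (where setup cost is paid) happens once, and fit is then called once per phase with a different dataloader. StubTrainer, make_loader, and the lists used as dataloaders are assumptions made for illustration; the sketch only counts constructions and batches and does not model when the real Trainer compiles or transfers weights.

```python
from typing import Iterable, List


class StubTrainer:
    """Hypothetical stand-in for Trainer; it only records what happens."""

    instances_created = 0  # how many times construction (setup cost) happened

    def __init__(self) -> None:
        StubTrainer.instances_created += 1
        self.steps_run = 0
        self.batch_sizes_seen: List[int] = []

    def fit(self, train_dataloader: Iterable[List[int]]) -> None:
        # One step per batch; a batch here is just a list of dummy sample ids.
        for batch in train_dataloader:
            self.steps_run += 1
            self.batch_sizes_seen.append(len(batch))


def make_loader(batch_size: int, num_batches: int) -> List[List[int]]:
    """Hypothetical dataloader: a list of batches of dummy sample ids."""
    return [list(range(batch_size)) for _ in range(num_batches)]


trainer = StubTrainer()                                  # constructed once
trainer.fit(make_loader(batch_size=80, num_batches=3))   # phase 1
trainer.fit(make_loader(batch_size=40, num_batches=2))   # phase 2

if __name__ == "__main__":
    print(StubTrainer.instances_created, trainer.steps_run)  # 1 5
    print(trainer.batch_sizes_seen)                          # [80, 80, 80, 40, 40]
```

The YAML route corresponds to creating a second StubTrainer for phase 2, which would bump instances_created to 2; that extra construction is where the overhead described above comes from.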
Multi-Phase Training (Advanced)
A more advanced example of Multi-Phase training involves changing Trainer parameters between training phases. For instance, you might want to switch the learning rate scheduler from CosineDecayLR to ConstantLR. To accomplish this, you need to create two Trainer instances and carefully manage checkpoint loading between phases to account for the changed parameters.
In the example below, the model, optimizer, and other parameters are similar to those in the previous example, so they have been omitted to keep the example short.
trainer:
  - trainer:
      init:
        ...
        schedulers:
          - CosineDecayLR:
              initial_learning_rate: 3.0e-5
              end_learning_rate: 3.0e-6
              total_iters: 528
  - trainer:
      init:
        ...
        schedulers:
          - ConstantLR:
              learning_rate: 1.0e-6
        callbacks:
          - LoadCheckpointStates:
              load_checkpoint_states: "model,grad_scaler,optimizer,global_step"
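To see what the learning rate looks like across these two phases, the sketch below evaluates both schedules step by step with the values from the YAML above. It assumes CosineDecayLR follows the common cosine-annealing formula lr(t) = end + (initial - end) * (1 + cos(pi * t / total_iters)) / 2, holding end_learning_rate once t reaches total_iters, and that ConstantLR simply returns its learning_rate at every step; check the scheduler reference for the exact definitions. The function names and the phase lengths passed in are hypothetical.

```python
import math
from typing import List


def cosine_decay_lr(step: int, initial_lr: float, end_lr: float, total_iters: int) -> float:
    """Assumed cosine decay: initial_lr at step 0, end_lr from total_iters onward."""
    t = min(step, total_iters)  # hold the final value after total_iters
    return end_lr + (initial_lr - end_lr) * 0.5 * (1.0 + math.cos(math.pi * t / total_iters))


def constant_lr(step: int, learning_rate: float) -> float:
    """Assumed ConstantLR: the same learning rate at every step."""
    return learning_rate


def two_phase_schedule(phase1_steps: int, phase2_steps: int) -> List[float]:
    """Phase 1 uses the cosine schedule; phase 2 switches to a constant rate."""
    phase1 = [cosine_decay_lr(s, 3.0e-5, 3.0e-6, 528) for s in range(phase1_steps)]
    phase2 = [constant_lr(s, 1.0e-6) for s in range(phase2_steps)]
    return phase1 + phase2


if __name__ == "__main__":
    lrs = two_phase_schedule(phase1_steps=529, phase2_steps=3)
    # Start of phase 1, midpoint, end of the cosine decay, first step of phase 2.
    for step in (0, 264, 528, 529):
        print(step, f"{lrs[step]:.2e}")
```

Under these assumptions, the learning rate falls from 3.0e-5 to 3.0e-6 during the cosine phase and then drops to 1.0e-6 when the ConstantLR phase begins.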
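In this example, each Trainer constructs and compiles its own model. Because the second phase switches the scheduler to ConstantLR, the LoadCheckpointStates callback lists exactly which states to load from the checkpoint: model, grad_scaler, optimizer, and global_step. The learning rate scheduler state is left out of that list, which avoids issues when loading the checkpoint into a phase that uses a different scheduler. For further reading, see Checkpointing.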
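The idea behind load_checkpoint_states can be sketched as filtering a checkpoint by a comma-separated list of state names. The checkpoint layout below, including the lr_scheduler key, is a hypothetical example rather than the actual Cerebras checkpoint format, and select_states is not the LoadCheckpointStates implementation; it only shows why leaving the scheduler out of the list lets the second phase start from a fresh ConstantLR state.

```python
from typing import Any, Dict


def select_states(checkpoint: Dict[str, Any], load_checkpoint_states: str) -> Dict[str, Any]:
    """Keep only the states named in a comma-separated string such as "model,optimizer"."""
    wanted = [name.strip() for name in load_checkpoint_states.split(",") if name.strip()]
    missing = [name for name in wanted if name not in checkpoint]
    if missing:
        raise KeyError(f"checkpoint has no state(s): {missing}")
    return {name: checkpoint[name] for name in wanted}


# Hypothetical checkpoint written at the end of the CosineDecayLR phase.
phase1_checkpoint: Dict[str, Any] = {
    "model": {"weights": "<tensor data>"},
    "optimizer": {"exp_avg": "<tensor data>"},
    "grad_scaler": {"scale": 1024.0},
    "global_step": 10000,
    "lr_scheduler": {"type": "CosineDecayLR", "last_step": 10000},
}

# Same string as in the YAML: everything except the scheduler state.
loaded = select_states(phase1_checkpoint, "model,grad_scaler,optimizer,global_step")

if __name__ == "__main__":
    print(sorted(loaded))  # ['global_step', 'grad_scaler', 'model', 'optimizer']
```

Raising on unknown names, rather than silently skipping them, is a design choice in this sketch so that a typo in the list is caught immediately.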
Caveats
When running Multi-Phase training with the Python API, you may hit the following error:
RuntimeError: Cannot instantiate multiple backends. A backend with type CSX has already been instantiated.
To avoid it, instantiate only a single backend and pass that same backend to every Trainer you construct. For example:
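import cerebras.pytorch as cstorch
from cerebras.modelzoo import Trainer

# Create the backend once...
backend = cstorch.backend(
    "CSX",
    # ... other backend arguments
)

# ...and pass the same backend object to the Trainer of every phase.
trainer1 = Trainer(
    backend=backend,
    # ... other Trainer arguments
)
trainer2 = Trainer(
    backend=backend,
    # ... other Trainer arguments
)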
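The pure-Python sketch below imitates the restriction behind this error with a hypothetical FakeBackend class: a second construction raises a RuntimeError with the message quoted above, while passing one shared instance to several trainer stand-ins works. FakeBackend and FakeTrainer are not part of cerebras.pytorch or the Model Zoo; the assumption that the real backend behaves like a process-wide singleton is inferred from the error message only.

```python
from typing import Optional


class FakeBackend:
    """Hypothetical stand-in that refuses to be instantiated a second time."""

    _instance: Optional["FakeBackend"] = None

    def __init__(self, backend_type: str) -> None:
        if FakeBackend._instance is not None:
            raise RuntimeError(
                "Cannot instantiate multiple backends. A backend with type "
                f"{FakeBackend._instance.backend_type} has already been instantiated."
            )
        self.backend_type = backend_type
        FakeBackend._instance = self  # remember the one allowed instance


class FakeTrainer:
    """Hypothetical trainer that only stores the backend it was given."""

    def __init__(self, backend: FakeBackend) -> None:
        self.backend = backend


backend = FakeBackend("CSX")
trainer1 = FakeTrainer(backend=backend)  # both phases share one backend
trainer2 = FakeTrainer(backend=backend)

if __name__ == "__main__":
    try:
        FakeBackend("CSX")  # what happens if a phase builds its own backend
    except RuntimeError as err:
        print(err)
```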
Conclusion
This tutorial showcases some of the use cases where Multi-Phase training can be applied. However, you are not limited to these examples and can construct as many Trainers as you need, combining different models, schedulers, optimizers, dataloaders, and more.