Note: Beta support for GPTJ indicates that we have confirmed the model’s μP functionality using coordinate checks, but have yet to perform long convergence runs.

Model Params

  • mup_base_hidden_size (Required to enable μP): The hidden size of the proxy model used in μTransfer to calculate the necessary scaling multipliers.

  • mup_base_filter_size (Required to enable μP): The filter size of the proxy model used in μTransfer to calculate the necessary scaling multipliers.

  • embeddings_scale: Scales the embedding hidden states (i.e., the tensor after the embeddings and embedding layer norm are applied). Recommended to tune for stabilizing gradient flow during μP training.

  • output_logits_alpha: Constant applied to the output logits scalar in μP training. The output logits are scaled by output_logits_alpha * mup_base_hidden_size / hidden_size. Recommended to tune for stabilizing the output logits in μP training.

  • scale_qk_dot_by_d: Scales the attention QK dot product by d instead of sqrt(d). Must be enabled for μP training.

  • attention_logits_alpha: Scales the attention QK dot product by the specified value. Recommended to tune for stabilizing the attention logits in μP training.

  • scale_output_logits_by_d: Scales the output logits in μP by mup_base_hidden_size / hidden_size if True and by sqrt(mup_base_hidden_size / hidden_size) if False. It is traditionally set to True in the μP implementation of this model. A short sketch of how these scalars combine is given after this list.
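
The following is a minimal sketch of how these scalars enter the computation, based on the descriptions above. All concrete sizes and alpha values are hypothetical, not recommendations:

# Illustrative muP scalars; every concrete value here is hypothetical.
mup_base_hidden_size = 256          # proxy model hidden size
hidden_size = 1024                  # target model hidden size
head_dim = 64                       # per-head dimension d, hypothetical

embeddings_scale = 10.0             # tuned constant
output_logits_alpha = 1.0           # tuned constant
attention_logits_alpha = 1.0        # tuned constant

width_ratio = mup_base_hidden_size / hidden_size

# Embedding hidden states are multiplied by embeddings_scale:
#   hidden_states = embeddings_scale * embedding_output

# Output logits: scale_output_logits_by_d: True uses the full width ratio,
# False uses its square root.
output_logits_scale = output_logits_alpha * width_ratio         # True
# output_logits_scale = output_logits_alpha * width_ratio**0.5  # False

# Attention logits: scale_qk_dot_by_d divides the QK dot product by d
# rather than sqrt(d), and attention_logits_alpha multiplies the result:
#   attn_logits = attention_logits_alpha * (q @ k.T) / head_dim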

Supported LR Adjustment Groups

  • embedding: Targets the embedding weights.

  • decoder_attention: Targets the attention dense layers in the decoder (Q, K, V, and output projections).

  • decoder_input_ffn: Targets the first of the two FFN blocks in the decoder.

  • decoder_output_ffn: Targets the final FFN block in the decoder. A sketch of the conventional per-group μP scaling follows this list.
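
For reference, the conventional μTransfer (Adam) rule scales the learning rate of the decoder's matrix-like weights by the width ratio while leaving the embedding learning rate unscaled. The sketch below illustrates that default for each group; it is an assumption based on standard μP practice rather than a verbatim description of this implementation, and the sizes are hypothetical:

# Conventional muP per-group LR scales (illustrative, not the library's exact code).
mup_base_hidden_size = 256
hidden_size = 1024
width_ratio = mup_base_hidden_size / hidden_size   # 0.25

default_lr_scales = {
    "embedding": 1.0,                   # embedding weights keep the base LR
    "decoder_attention": width_ratio,   # Q, K, V, and output projections
    "decoder_input_ffn": width_ratio,   # first FFN block
    "decoder_output_ffn": width_ratio,  # final FFN block
}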

Example Configuration

model:
  ...
  # muP
  scale_qk_dot_by_d: True
  scale_output_logits_by_d: True
  mup_base_hidden_size: ...
  mup_base_filter_size: ...
  output_logits_alpha: ...
  attention_logits_alpha: ...
  embeddings_scale: ...

optimizer:
  ...
  adjust_learning_rate:
    embedding: ...
    decoder_input_ffn: ...
    ...

Note that not all of the LR adjustment groups are specified: these values only need to be provided in the config if you want to override the traditional μP LR scaling.
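
As a rough sketch of that override behavior, assuming the values under adjust_learning_rate act as multiplicative factors on the optimizer's learning rate (all numbers are hypothetical):

# Hypothetical override: a group listed under adjust_learning_rate replaces its
# conventional muP scale; omitted groups keep the default (0.25 in the sketch above).
base_lr = 6e-4                             # illustrative optimizer learning rate
decoder_input_ffn_lr = base_lr * 0.1       # overridden in the config (hypothetical value)
decoder_output_ffn_lr = base_lr * 0.25     # not specified, keeps the width-ratio default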