use_transformer_initialization set to False.
d_kv and num_heads between the proxy model and the target model may cause instability.
mup_base_d_model (Required to enable μP): The d_model of the base model in μP transfer, used to calculate the necessary multipliers.
mup_base_d_ff (Required to enable μP): The d_ff of the base model in μP transfer, used to calculate the necessary multipliers.
mup_base_d_kv (Highly Experimental): The d_kv of the base model in μP transfer, used to calculate the necessary multipliers. Only use when varying d_kv alongside d_model.
embeddings_alpha: Scales the embedding hidden states (i.e., the tensor after the embeddings and the embedding layer norm are applied). The embeddings are scaled by embeddings_alpha * sqrt(d_model). Recommended to tune for stabilizing gradient flow during μP training.
output_logits_alpha: Constant applied to the output logits scalar in μP training. The output logits are scaled by output_logits_alpha * mup_base_d_model/d_model. Recommended to tune for stabilizing output logits in μP training.
scale_encoder_qk_dot_by_d: Scales the encoder attention QK dot product by 1/d instead of 1/sqrt(d). Must be enabled for μP training.
scale_decoder_qk_dot_by_d: Scales the decoder attention QK dot product by 1/d instead of 1/sqrt(d). Must be enabled for μP training.
encoder_attention_logits_alpha: Additionally scales the encoder attention QK dot product by the specified value. Recommended to tune for stabilizing attention logits in μP training.
decoder_attention_logits_alpha: Additionally scales the decoder attention QK dot product by the specified value. Recommended to tune for stabilizing attention logits in μP training.
scale_output_logits_by_d: Scales the output logits in μP by mup_base_d_model/d_model if True and by sqrt(mup_base_d_model/d_model) if False. Unlike other models, the default for the μP implementation in this model is False.
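
The scaling rules described by these parameters can be illustrated with a minimal Python sketch. The function names are hypothetical and are not part of any library API. The sketch also assumes that d in the attention scaling is the per-head dimension (d_kv), and that output_logits_alpha multiplies the width factor in both the True and False cases of scale_output_logits_by_d. Treat it as an illustration of the arithmetic, not as the actual implementation.

```python
import math


def embedding_scale(embeddings_alpha, d_model):
    """Multiplier for the embedding hidden states: alpha * sqrt(d_model)."""
    return embeddings_alpha * math.sqrt(d_model)


def output_logits_scale(output_logits_alpha, mup_base_d_model, d_model,
                        scale_output_logits_by_d=False):
    """Multiplier for the output logits.

    Assumption: output_logits_alpha is applied on top of the width factor
    regardless of how scale_output_logits_by_d is set.
    """
    ratio = mup_base_d_model / d_model
    width_factor = ratio if scale_output_logits_by_d else math.sqrt(ratio)
    return output_logits_alpha * width_factor


def attention_logit(q, k, scale_qk_dot_by_d=True, attention_logits_alpha=1.0):
    """Scaled QK dot product for one query vector and one key vector.

    Assumption: d is the per-head dimension, taken here as len(q).
    """
    d = len(q)
    dot = sum(qi * ki for qi, ki in zip(q, k))
    denom = d if scale_qk_dot_by_d else math.sqrt(d)
    return attention_logits_alpha * dot / denom


# Example: base model with d_model=256 transferred to a target with d_model=1024.
emb = embedding_scale(1.0, 1024)                                         # 32.0
out_by_d = output_logits_scale(1.0, 256, 1024, scale_output_logits_by_d=True)
out_by_sqrt = output_logits_scale(1.0, 256, 1024)                        # default False

q = [1.0] * 64
k = [1.0] * 64
mup_logit = attention_logit(q, k, scale_qk_dot_by_d=True)     # divides by d
sp_logit = attention_logit(q, k, scale_qk_dot_by_d=False)     # divides by sqrt(d)
```

With these inputs, dividing by d keeps the attention logit at 1.0 while dividing by sqrt(d) gives 8.0, which shows why the 1/d rule keeps logits from growing as the head dimension increases.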
embedding: Targets the embedding weights.
encoder_qkv_projection: Targets the Q, K, V projection dense layers in the encoder.
encoder_output_projection: Targets the output projection dense layer in the encoder.
encoder_input_ffn: Targets the first of the two FFN blocks in the encoder.
encoder_output_ffn: Targets the final FFN block in the encoder.
decoder_qkv_projection: Targets the Q, K, V projection dense layers in the decoder.
decoder_output_projection: Targets the output projection dense layer in the decoder.
decoder_input_ffn: Targets the first of the two FFN blocks in the decoder.
decoder_output_ffn: Targets the final FFN block in the decoder.
lm_head: Targets the weights of the lm_head.
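
When these target names are supplied by hand, a typo is easy to miss. The following is a minimal sketch of a validation helper. The helper name `split_targets` and the "shared" bucket for embedding and lm_head are hypothetical choices made for illustration and are not part of any library API. Only the ten target names themselves come from the list above.

```python
# Target names exactly as listed above.
VALID_TARGETS = (
    "embedding",
    "encoder_qkv_projection",
    "encoder_output_projection",
    "encoder_input_ffn",
    "encoder_output_ffn",
    "decoder_qkv_projection",
    "decoder_output_projection",
    "decoder_input_ffn",
    "decoder_output_ffn",
    "lm_head",
)


def split_targets(names):
    """Reject unknown target names, then group the rest by model component."""
    unknown = [n for n in names if n not in VALID_TARGETS]
    if unknown:
        raise ValueError(f"Unknown target(s): {unknown}")

    # "shared" is a hypothetical label for targets outside encoder/decoder.
    groups = {"encoder": [], "decoder": [], "shared": []}
    for name in names:
        if name.startswith("encoder_"):
            groups["encoder"].append(name)
        elif name.startswith("decoder_"):
            groups["decoder"].append(name)
        else:
            groups["shared"].append(name)
    return groups


groups = split_targets(["embedding", "encoder_input_ffn", "decoder_qkv_projection"])
```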