diff --git a/i6_models/assemblies/transformer/transformer_decoder_v1.py b/i6_models/assemblies/transformer/transformer_decoder_v1.py
index 773e1474..ab995857 100644
--- a/i6_models/assemblies/transformer/transformer_decoder_v1.py
+++ b/i6_models/assemblies/transformer/transformer_decoder_v1.py
@@ -128,7 +128,7 @@ class TransformerDecoderV1Config(ModelConfiguration):
         block_cfg: Configuration for TransformerDecoderV1.
         input_dropout: Dropout applied to the input embedding.
         input_embedding_scale: Scale applied to the input embedding.
-            Set to `None` to apply a (tuned) default.
+            Set to `None` to apply a default of sqrt(model_dim).
         num_blocks: Number of transformer blocks in the decoder.
         num_output: Number of output labels/vocab dim.
         logits_bias: Whether to add a bias to the output logits.