From 0fa1a41162533e5a04a5693f6410d21897021bda Mon Sep 17 00:00:00 2001
From: Gerstenberger
Date: Thu, 29 Jan 2026 11:18:41 +0100
Subject: [PATCH 1/4] TransformerDecoder: optional positional encoding and
 final matmul

---
 .../transformer/transformer_decoder_v1.py     | 45 +++++++++++++------
 1 file changed, 32 insertions(+), 13 deletions(-)

diff --git a/i6_models/assemblies/transformer/transformer_decoder_v1.py b/i6_models/assemblies/transformer/transformer_decoder_v1.py
index f4d0e1f2..38710c56 100644
--- a/i6_models/assemblies/transformer/transformer_decoder_v1.py
+++ b/i6_models/assemblies/transformer/transformer_decoder_v1.py
@@ -20,7 +20,7 @@ import torch.nn.functional as F
 from dataclasses import dataclass, field
-from typing import List, Optional, Tuple, TypedDict, Union
+from typing import List, Optional, Tuple, TypedDict, Union, NotRequired
 
 from i6_models.config import ModelConfiguration
 from i6_models.parts.conformer import (
     ConformerMHSARelPosV1,
     ConformerPositionwiseFeedForwardV2,
@@ -141,6 +141,8 @@ class TransformerDecoderV1Config(ModelConfiguration):
         logits_bias: Whether to add a bias to the output logits. Usually False is a good choice.
         share_embedding: Whether to share the input and output embedding.
+        use_positional_encoding: Whether to apply a sinusoidal positional encoding to the initial input.
+        do_output_embedding_matmul: Whether to apply the final matmul of the model output with the output embedding.
     """
 
     block_cfg: TransformerDecoderBlockV1Config
@@ -150,13 +152,15 @@ class TransformerDecoderV1Config(ModelConfiguration):
     num_output: int
     logits_bias: bool
     share_embedding: bool
+    use_positional_encoding: bool = True
+    do_output_embedding_matmul: bool = True
 
 
 class TransformerDecoderV1State(TypedDict):
     """Recurrent state of the transformer decoder."""
 
     block_state: List[TransformerDecoderBlockV1State]
-    pos: Tensor
+    pos: NotRequired[Tensor]
 
 
 class TransformerDecoderV1(nn.Module, ModuleWithState[TransformerDecoderV1State]):
@@ -190,13 +194,20 @@ def __init__(self, cfg: TransformerDecoderV1Config):
         else:
             self.out_logits = nn.Linear(self.model_dim, cfg.num_output, bias=cfg.logits_bias)
 
+        self.use_positional_encoding = cfg.use_positional_encoding
+        self.do_output_embedding_matmul = cfg.do_output_embedding_matmul
+
     def get_initial_state(self) -> TransformerDecoderV1State:
         """:return: initial decoder state"""
-        return {
+        state: TransformerDecoderV1State = {
             "block_state": [block.get_initial_state() for block in self.module_list],
-            "pos": torch.tensor(0, dtype=torch.int32),
         }
 
+        if self.use_positional_encoding:
+            state["pos"] = torch.tensor(0, dtype=torch.int32)
+
+        return state
+
     def transform_encoder_output(
         self,
         encoder_output: Tensor,
@@ -229,10 +240,13 @@ def forward(
         - `enc_out, enc_out_mask = forward_some_encoder(...)` and
         - `s = get_initial_state()`.
""" x = self.input_embedding(labels) * self.input_embedding_scale - sinus_pe = ConformerMHSARelPosV1._sinusoidal_pe( - torch.arange(labels.shape[-1], device=labels.device) + state["pos"], self.model_dim - ) - x = x + sinus_pe.unsqueeze(0) + + if self.use_positional_encoding: + sinus_pe = ConformerMHSARelPosV1._sinusoidal_pe( + torch.arange(labels.shape[-1], device=labels.device) + state["pos"], self.model_dim + ) + x = x + sinus_pe.unsqueeze(0) + x = self.input_dropout(x) output = x @@ -243,11 +257,16 @@ def forward( new_state: TransformerDecoderV1State = { **state, "block_state": new_block_states, - "pos": state["pos"] + labels_lens.max(), } + if self.use_positional_encoding: + new_state["pos"] = state["pos"] + labels_lens.max() + output = self.out_norm(output) - output_logits = ( - F.linear(output, self.input_embedding.weight, None) if self.share_embedding else self.out_logits(output) - ) - return output_logits, new_state + + if self.do_output_embedding_matmul: + output = ( + F.linear(output, self.input_embedding.weight, None) if self.share_embedding else self.out_logits(output) + ) + + return output, new_state From a2897e544d1c33f47192419d3b36aa9817f89c33 Mon Sep 17 00:00:00 2001 From: Gerstenberger Date: Wed, 11 Feb 2026 17:03:56 +0100 Subject: [PATCH 2/4] initial proposal to commentary --- .../transformer/transformer_decoder_v1.py | 106 +++++++++++++----- 1 file changed, 75 insertions(+), 31 deletions(-) diff --git a/i6_models/assemblies/transformer/transformer_decoder_v1.py b/i6_models/assemblies/transformer/transformer_decoder_v1.py index 38710c56..17ddf036 100644 --- a/i6_models/assemblies/transformer/transformer_decoder_v1.py +++ b/i6_models/assemblies/transformer/transformer_decoder_v1.py @@ -17,12 +17,11 @@ import torch from torch import nn, Tensor -import torch.nn.functional as F from dataclasses import dataclass, field from typing import List, Optional, Tuple, TypedDict, Union, NotRequired -from i6_models.config import ModelConfiguration +from i6_models.config import ModelConfiguration, ModuleFactoryV1 from i6_models.parts.conformer import ( ConformerMHSARelPosV1, ConformerPositionwiseFeedForwardV2, @@ -128,6 +127,52 @@ def forward( return labels, {**state, "module_states": new_states} +@dataclass +class SinusoidalPositionalEncodingV1Config(ModelConfiguration): + """ + Attributes: + embedding_dim: embedding dimension + """ + embedding_dim = int + + +class PositionalEncodingV1State(TypedDict): + """ + State for some positional encoding. + """ + pos: Tensor + + +class SinusoidalPositionalEncodingV1(nn.Module, ModuleWithState[PositionalEncodingV1State]): + """ + Computes and applies a sinusoidal positional encoding. + """ + def __init__(self, cfg: SinusoidalPositionalEncodingV1Config): + super().__init__() + + self.embed_dim = cfg.embedding_dim + + def forward(self, inputs: Tensor, lengths: Tensor, state: PositionalEncodingV1State): + """ + Apply sinusoidal positional encoding on the inputs. + + :param inputs: tensor to apply the positional encoding on + :param lengths: input lengths + :param state: current state of positional encoding. 
+ """ + sinus_pe = ConformerMHSARelPosV1._sinusoidal_pe( + torch.arange(inputs.shape[-1], device=inputs.device) + state["pos"], self.embed_dim + ) + output = inputs + sinus_pe.unsqueeze(0) + + new_state: PositionalEncodingV1State = {"pos": state["pos"] + lengths.max()} + + return output, new_state + + def get_initial_state(self) -> PositionalEncodingV1State: + return {"pos": Tensor(0, dtype=torch.int32)} + + @dataclass class TransformerDecoderV1Config(ModelConfiguration): """ @@ -141,8 +186,8 @@ class TransformerDecoderV1Config(ModelConfiguration): logits_bias: Whether to add a bias to the output logits. Usually False is a good choice. share_embedding: Whether to share the input and output embedding. - use_positional_encoding: use a sinus positional encoding on the initial input - do_output_embedding_matmul: apply the final model output x output embedding matmul + positional_encoding: optionally apply some positional encoding to the input embeddings. + output_linear_projection: Whether to apply a linear projection on the model output to 'num_output' dimension. """ block_cfg: TransformerDecoderBlockV1Config @@ -152,15 +197,15 @@ class TransformerDecoderV1Config(ModelConfiguration): num_output: int logits_bias: bool share_embedding: bool - use_positional_encoding: bool = True - do_output_embedding_matmul: bool = True + positional_encoding: Optional[ModuleFactoryV1] + output_linear_projection: bool = True class TransformerDecoderV1State(TypedDict): """Recurrent state of the transformer decoder.""" block_state: List[TransformerDecoderBlockV1State] - pos: NotRequired[Tensor] + pos_state: NotRequired[PositionalEncodingV1State] class TransformerDecoderV1(nn.Module, ModuleWithState[TransformerDecoderV1State]): @@ -188,14 +233,21 @@ def __init__(self, cfg: TransformerDecoderV1Config): ) self.out_norm = nn.LayerNorm(self.model_dim) self.share_embedding = cfg.share_embedding - if cfg.share_embedding: - assert not cfg.logits_bias, "Cannot use logits bias with shared embedding" - nn.init.xavier_uniform_(self.input_embedding.weight) # bad convergence with default init + + cfg.positional_encoding = None + if cfg.positional_encoding is not None: + self.positional_encoding = cfg.positional_encoding() + + self.output_linear_projection = cfg.output_linear_projection + + if not self.output_linear_projection: + self.out_logits = nn.Identity() else: self.out_logits = nn.Linear(self.model_dim, cfg.num_output, bias=cfg.logits_bias) - self.use_positional_encoding = cfg.use_positional_encoding - self.do_output_embedding_matmul = cfg.do_output_embedding_matmul + if cfg.share_embedding and self.output_linear_projection: + self.out_logits.weight = self.input_embedding.weight + nn.init.xavier_uniform_(self.input_embedding.weight) # bad convergence with default init def get_initial_state(self) -> TransformerDecoderV1State: """:return: initial decoder state""" @@ -203,8 +255,8 @@ def get_initial_state(self) -> TransformerDecoderV1State: "block_state": [block.get_initial_state() for block in self.module_list], } - if self.use_positional_encoding: - state["pos"] = torch.tensor(0, dtype=torch.int32) + if self.positional_encoding is not None: + state["pos_state"] = self.positional_encoding.get_initial_state() return state @@ -239,13 +291,15 @@ def forward( - `enc_out, enc_out_mask = forward_some_encoder(...)` and - `s = get_initial_state()`. 
""" + new_state: TransformerDecoderV1State = { + **state, + } + x = self.input_embedding(labels) * self.input_embedding_scale - if self.use_positional_encoding: - sinus_pe = ConformerMHSARelPosV1._sinusoidal_pe( - torch.arange(labels.shape[-1], device=labels.device) + state["pos"], self.model_dim - ) - x = x + sinus_pe.unsqueeze(0) + if self.positional_encoding is not None: + x, new_pos_state = self.positional_encoding(x, labels_lens, state["pos"]) + new_state["pos_state"] = new_pos_state x = self.input_dropout(x) @@ -254,19 +308,9 @@ def forward( for block, block_state in zip(self.module_list, state["block_state"]): output, new_block_state = block(output, labels_lens, block_state) new_block_states.append(new_block_state) - new_state: TransformerDecoderV1State = { - **state, - "block_state": new_block_states, - } - - if self.use_positional_encoding: - new_state["pos"] = state["pos"] + labels_lens.max() + new_state["block_state"] = new_block_states output = self.out_norm(output) - - if self.do_output_embedding_matmul: - output = ( - F.linear(output, self.input_embedding.weight, None) if self.share_embedding else self.out_logits(output) - ) + output = self.out_logits(output) return output, new_state From 9ec338f597d41964c5bfbf50bcfc5c287ffb4c21 Mon Sep 17 00:00:00 2001 From: Gerstenberger Date: Wed, 11 Feb 2026 17:13:35 +0100 Subject: [PATCH 3/4] formating --- i6_models/assemblies/transformer/transformer_decoder_v1.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/i6_models/assemblies/transformer/transformer_decoder_v1.py b/i6_models/assemblies/transformer/transformer_decoder_v1.py index 17ddf036..4153763f 100644 --- a/i6_models/assemblies/transformer/transformer_decoder_v1.py +++ b/i6_models/assemblies/transformer/transformer_decoder_v1.py @@ -133,6 +133,7 @@ class SinusoidalPositionalEncodingV1Config(ModelConfiguration): Attributes: embedding_dim: embedding dimension """ + embedding_dim = int @@ -140,6 +141,7 @@ class PositionalEncodingV1State(TypedDict): """ State for some positional encoding. """ + pos: Tensor @@ -147,6 +149,7 @@ class SinusoidalPositionalEncodingV1(nn.Module, ModuleWithState[PositionalEncodi """ Computes and applies a sinusoidal positional encoding. """ + def __init__(self, cfg: SinusoidalPositionalEncodingV1Config): super().__init__() From 86bfe3e4156d1a0dd2db19e20db7f4c1dd628c87 Mon Sep 17 00:00:00 2001 From: Gerstenberger Date: Wed, 11 Feb 2026 17:23:14 +0100 Subject: [PATCH 4/4] fix input to _sinusoidal_pe --- .../assemblies/transformer/transformer_decoder_v1.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/i6_models/assemblies/transformer/transformer_decoder_v1.py b/i6_models/assemblies/transformer/transformer_decoder_v1.py index 4153763f..97179ea3 100644 --- a/i6_models/assemblies/transformer/transformer_decoder_v1.py +++ b/i6_models/assemblies/transformer/transformer_decoder_v1.py @@ -155,16 +155,17 @@ def __init__(self, cfg: SinusoidalPositionalEncodingV1Config): self.embed_dim = cfg.embedding_dim - def forward(self, inputs: Tensor, lengths: Tensor, state: PositionalEncodingV1State): + def forward(self, inputs: Tensor, labels: Tensor, lengths: Tensor, state: PositionalEncodingV1State): """ Apply sinusoidal positional encoding on the inputs. - :param inputs: tensor to apply the positional encoding on + :param inputs: input embeddings + :param labels: input labels :param lengths: input lengths :param state: current state of positional encoding. 
""" sinus_pe = ConformerMHSARelPosV1._sinusoidal_pe( - torch.arange(inputs.shape[-1], device=inputs.device) + state["pos"], self.embed_dim + torch.arange(labels.shape[-1], device=labels.device) + state["pos"], self.embed_dim ) output = inputs + sinus_pe.unsqueeze(0) @@ -301,7 +302,7 @@ def forward( x = self.input_embedding(labels) * self.input_embedding_scale if self.positional_encoding is not None: - x, new_pos_state = self.positional_encoding(x, labels_lens, state["pos"]) + x, new_pos_state = self.positional_encoding(x, labels, labels_lens, state["pos"]) new_state["pos_state"] = new_pos_state x = self.input_dropout(x)