diff --git a/i6_models/assemblies/transformer/transformer_decoder_v1.py b/i6_models/assemblies/transformer/transformer_decoder_v1.py
index f4d0e1f2..97179ea3 100644
--- a/i6_models/assemblies/transformer/transformer_decoder_v1.py
+++ b/i6_models/assemblies/transformer/transformer_decoder_v1.py
@@ -17,12 +17,11 @@
 import torch
 from torch import nn, Tensor
-import torch.nn.functional as F
 
 from dataclasses import dataclass, field
-from typing import List, Optional, Tuple, TypedDict, Union
+from typing import List, Optional, Tuple, TypedDict, Union, NotRequired
 
-from i6_models.config import ModelConfiguration
+from i6_models.config import ModelConfiguration, ModuleFactoryV1
 from i6_models.parts.conformer import (
     ConformerMHSARelPosV1,
     ConformerPositionwiseFeedForwardV2,
@@ -128,6 +127,56 @@ def forward(
         return labels, {**state, "module_states": new_states}
 
 
+@dataclass
+class SinusoidalPositionalEncodingV1Config(ModelConfiguration):
+    """
+    Attributes:
+        embedding_dim: embedding dimension
+    """
+
+    embedding_dim: int
+
+
+class PositionalEncodingV1State(TypedDict):
+    """
+    State for some positional encoding.
+    """
+
+    pos: Tensor
+
+
+class SinusoidalPositionalEncodingV1(nn.Module, ModuleWithState[PositionalEncodingV1State]):
+    """
+    Computes and applies a sinusoidal positional encoding.
+    """
+
+    def __init__(self, cfg: SinusoidalPositionalEncodingV1Config):
+        super().__init__()
+
+        self.embed_dim = cfg.embedding_dim
+
+    def forward(self, inputs: Tensor, labels: Tensor, lengths: Tensor, state: PositionalEncodingV1State):
+        """
+        Apply sinusoidal positional encoding on the inputs.
+
+        :param inputs: input embeddings
+        :param labels: input labels
+        :param lengths: input lengths
+        :param state: current state of positional encoding.
+        """
+        sinus_pe = ConformerMHSARelPosV1._sinusoidal_pe(
+            torch.arange(labels.shape[-1], device=labels.device) + state["pos"], self.embed_dim
+        )
+        output = inputs + sinus_pe.unsqueeze(0)
+
+        new_state: PositionalEncodingV1State = {"pos": state["pos"] + lengths.max()}
+
+        return output, new_state
+
+    def get_initial_state(self) -> PositionalEncodingV1State:
+        return {"pos": torch.tensor(0, dtype=torch.int32)}
+
+
 @dataclass
 class TransformerDecoderV1Config(ModelConfiguration):
     """
@@ -141,6 +190,8 @@ class TransformerDecoderV1Config(ModelConfiguration):
         logits_bias: Whether to add a bias to the output logits. Usually False is a good choice.
         share_embedding: Whether to share the input and output embedding.
+        positional_encoding: optionally apply some positional encoding to the input embeddings.
+        output_linear_projection: Whether to apply a linear projection on the model output to 'num_output' dimension.
     """
 
     block_cfg: TransformerDecoderBlockV1Config
@@ -150,13 +201,15 @@
     num_output: int
     logits_bias: bool
     share_embedding: bool
+    positional_encoding: Optional[ModuleFactoryV1]
+    output_linear_projection: bool = True
 
 
 class TransformerDecoderV1State(TypedDict):
     """Recurrent state of the transformer decoder."""
 
     block_state: List[TransformerDecoderBlockV1State]
-    pos: Tensor
+    pos_state: NotRequired[PositionalEncodingV1State]
 
 
 class TransformerDecoderV1(nn.Module, ModuleWithState[TransformerDecoderV1State]):
@@ -184,19 +237,33 @@ def __init__(self, cfg: TransformerDecoderV1Config):
         )
         self.out_norm = nn.LayerNorm(self.model_dim)
         self.share_embedding = cfg.share_embedding
-        if cfg.share_embedding:
-            assert not cfg.logits_bias, "Cannot use logits bias with shared embedding"
-            nn.init.xavier_uniform_(self.input_embedding.weight)  # bad convergence with default init
+
+        self.positional_encoding = None
+        if cfg.positional_encoding is not None:
+            self.positional_encoding = cfg.positional_encoding()
+
+        self.output_linear_projection = cfg.output_linear_projection
+
+        if not self.output_linear_projection:
+            self.out_logits = nn.Identity()
         else:
             self.out_logits = nn.Linear(self.model_dim, cfg.num_output, bias=cfg.logits_bias)
 
+        if cfg.share_embedding and self.output_linear_projection:
+            self.out_logits.weight = self.input_embedding.weight
+            nn.init.xavier_uniform_(self.input_embedding.weight)  # bad convergence with default init
+
     def get_initial_state(self) -> TransformerDecoderV1State:
         """:return: initial decoder state"""
-        return {
+        state: TransformerDecoderV1State = {
             "block_state": [block.get_initial_state() for block in self.module_list],
-            "pos": torch.tensor(0, dtype=torch.int32),
         }
+        if self.positional_encoding is not None:
+            state["pos_state"] = self.positional_encoding.get_initial_state()
+
+        return state
+
     def transform_encoder_output(
         self,
         encoder_output: Tensor,
@@ -228,11 +295,16 @@ def forward(
         - `enc_out, enc_out_mask = forward_some_encoder(...)` and
         - `s = get_initial_state()`.
         """
+        new_state: TransformerDecoderV1State = {
+            **state,
+        }
+
         x = self.input_embedding(labels) * self.input_embedding_scale
-        sinus_pe = ConformerMHSARelPosV1._sinusoidal_pe(
-            torch.arange(labels.shape[-1], device=labels.device) + state["pos"], self.model_dim
-        )
-        x = x + sinus_pe.unsqueeze(0)
+
+        if self.positional_encoding is not None:
+            x, new_pos_state = self.positional_encoding(x, labels, labels_lens, state["pos_state"])
+            new_state["pos_state"] = new_pos_state
+
         x = self.input_dropout(x)
 
         output = x
@@ -240,14 +312,9 @@
         for block, block_state in zip(self.module_list, state["block_state"]):
             output, new_block_state = block(output, labels_lens, block_state)
             new_block_states.append(new_block_state)
-        new_state: TransformerDecoderV1State = {
-            **state,
-            "block_state": new_block_states,
-            "pos": state["pos"] + labels_lens.max(),
-        }
+        new_state["block_state"] = new_block_states
 
         output = self.out_norm(output)
-        output_logits = (
-            F.linear(output, self.input_embedding.weight, None) if self.share_embedding else self.out_logits(output)
-        )
-        return output_logits, new_state
+        output = self.out_logits(output)
+
+        return output, new_state