TransformerDecoder: optional positional encoding and final matmul #93

    @@ -17,12 +17,11 @@
     import torch
     from torch import nn, Tensor
     import torch.nn.functional as F

     from dataclasses import dataclass, field
    -from typing import List, Optional, Tuple, TypedDict, Union
    +from typing import List, Optional, Tuple, TypedDict, Union, NotRequired

    -from i6_models.config import ModelConfiguration
    +from i6_models.config import ModelConfiguration, ModuleFactoryV1
     from i6_models.parts.conformer import (
         ConformerMHSARelPosV1,
         ConformerPositionwiseFeedForwardV2,
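
One portability note on the new `NotRequired` import: it only exists in `typing` from Python 3.11 on. A hedged sketch of a backward-compatible import, assuming `typing_extensions` is available (the diff itself does not state which Python versions are supported):

```python
# Not part of the PR: NotRequired was added to typing in Python 3.11;
# on older interpreters it can be taken from typing_extensions instead.
try:
    from typing import NotRequired
except ImportError:  # Python < 3.11
    from typing_extensions import NotRequired
```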

    @@ -128,6 +127,56 @@ def forward(
             return labels, {**state, "module_states": new_states}


    +@dataclass
    +class SinusoidalPositionalEncodingV1Config(ModelConfiguration):
    +    """
    +    Attributes:
    +        embedding_dim: embedding dimension
    +    """
    +
    +    embedding_dim: int
    +
    +
    +class PositionalEncodingV1State(TypedDict):
    +    """
    +    State for some positional encoding.
    +    """
    +
    +    pos: Tensor
    +
    +
    +class SinusoidalPositionalEncodingV1(nn.Module, ModuleWithState[PositionalEncodingV1State]):
    +    """
    +    Computes and applies a sinusoidal positional encoding.
    +    """
    +
    +    def __init__(self, cfg: SinusoidalPositionalEncodingV1Config):
    +        super().__init__()
    +
    +        self.embed_dim = cfg.embedding_dim
    +
    +    def forward(self, inputs: Tensor, labels: Tensor, lengths: Tensor, state: PositionalEncodingV1State):
    +        """
    +        Apply sinusoidal positional encoding on the inputs.
    +
    +        :param inputs: input embeddings
    +        :param labels: input labels
    +        :param lengths: input lengths
    +        :param state: current state of positional encoding.
    +        """
    +        sinus_pe = ConformerMHSARelPosV1._sinusoidal_pe(
    +            torch.arange(labels.shape[-1], device=labels.device) + state["pos"], self.embed_dim

**Contributor (Author)**, commenting on the line above: Do we have the constraint …

    +        )
    +        output = inputs + sinus_pe.unsqueeze(0)
    +
    +        new_state: PositionalEncodingV1State = {"pos": state["pos"] + lengths.max()}
    +
    +        return output, new_state
    +
    +    def get_initial_state(self) -> PositionalEncodingV1State:
    +        return {"pos": torch.tensor(0, dtype=torch.int32)}
    +
    +
     @dataclass
     class TransformerDecoderV1Config(ModelConfiguration):
         """
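
To make the mechanism of the stateful encoding above concrete, here is a small self-contained sketch. It uses a textbook interleaved sin/cos formula rather than the repo's `ConformerMHSARelPosV1._sinusoidal_pe` (whose exact layout is not shown in this diff), so treat it only as an illustration of the position offset: the `pos` carried in the state shifts the position indices so that the next decoding step continues where the previous one stopped.

```python
import torch


def toy_sinusoidal_pe(positions: torch.Tensor, dim: int) -> torch.Tensor:
    """Textbook sinusoidal encoding of shape [len(positions), dim]; dim assumed even."""
    inv_freq = 1.0 / (10000.0 ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    angles = positions.float().unsqueeze(-1) * inv_freq  # [T, dim // 2]
    pe = torch.zeros(positions.shape[0], dim)
    pe[:, 0::2] = torch.sin(angles)
    pe[:, 1::2] = torch.cos(angles)
    return pe


# Step 1: a chunk of 4 labels, encoded at positions 0..3.
pos = torch.tensor(0, dtype=torch.int32)
pe_step1 = toy_sinusoidal_pe(torch.arange(4) + pos, dim=8)
pos = pos + 4  # mirrors new_state["pos"] = state["pos"] + lengths.max()

# Step 2: the next chunk of 3 labels is encoded at positions 4..6, not 0..2.
pe_step2 = toy_sinusoidal_pe(torch.arange(3) + pos, dim=8)
print(pe_step1.shape, pe_step2.shape)  # torch.Size([4, 8]) torch.Size([3, 8])
```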

    @@ -141,6 +190,8 @@ class TransformerDecoderV1Config(ModelConfiguration):
             logits_bias: Whether to add a bias to the output logits.
                 Usually False is a good choice.
             share_embedding: Whether to share the input and output embedding.
    +        positional_encoding: optionally apply some positional encoding to the input embeddings.
    +        output_linear_projection: Whether to apply a linear projection on the model output to 'num_output' dimension.
         """

         block_cfg: TransformerDecoderBlockV1Config

    @@ -150,13 +201,15 @@ class TransformerDecoderV1Config(ModelConfiguration):
         num_output: int
         logits_bias: bool
         share_embedding: bool
    +    positional_encoding: Optional[ModuleFactoryV1]
    +    output_linear_projection: bool = True


     class TransformerDecoderV1State(TypedDict):
         """Recurrent state of the transformer decoder."""

         block_state: List[TransformerDecoderBlockV1State]
    -    pos: Tensor
    +    pos_state: NotRequired[PositionalEncodingV1State]

**Contributor (Author)**: Okay, maybe we should not change the name and type here, as this breaks existing setups.


     class TransformerDecoderV1(nn.Module, ModuleWithState[TransformerDecoderV1State]):
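
The `NotRequired` on `pos_state` means the key may be absent from the state dict entirely, which is how the decoder state can omit positional-encoding state when no such module is configured. A short, self-contained illustration with hypothetical stand-in names (plain `typing` and `torch`, not the repo's classes; requires Python 3.11+ or the `typing_extensions` fallback shown earlier):

```python
import torch
from typing import List, TypedDict, NotRequired


class ToyDecoderState(TypedDict):
    """Hypothetical stand-in for TransformerDecoderV1State."""

    block_state: List[dict]
    pos_state: NotRequired[dict]  # may be missing entirely when no positional encoding is used


state_without_pe: ToyDecoderState = {"block_state": []}  # valid: pos_state simply absent
state_with_pe: ToyDecoderState = {
    "block_state": [],
    "pos_state": {"pos": torch.tensor(0, dtype=torch.int32)},
}

# Consumers check for presence instead of assuming the key exists:
if "pos_state" in state_with_pe:
    print(state_with_pe["pos_state"]["pos"])  # tensor(0, dtype=torch.int32)
```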

    @@ -184,19 +237,33 @@ def __init__(self, cfg: TransformerDecoderV1Config):
             )
             self.out_norm = nn.LayerNorm(self.model_dim)
             self.share_embedding = cfg.share_embedding
    -        if cfg.share_embedding:
    -            assert not cfg.logits_bias, "Cannot use logits bias with shared embedding"
    -            nn.init.xavier_uniform_(self.input_embedding.weight)  # bad convergence with default init

    +        self.positional_encoding = None
    +        if cfg.positional_encoding is not None:
    +            self.positional_encoding = cfg.positional_encoding()
    +
    +        self.output_linear_projection = cfg.output_linear_projection
    +
    +        if not self.output_linear_projection:
    +            self.out_logits = nn.Identity()
    +        else:
    +            self.out_logits = nn.Linear(self.model_dim, cfg.num_output, bias=cfg.logits_bias)

**Member**: I just realize, this sharing is weird. I would always set […]. Also, […]. And then you would also allow to have `logits_bias=True` with `share_embedding=True`.

    +
    +        if cfg.share_embedding and self.output_linear_projection:
    +            self.out_logits.weight = self.input_embedding.weight
    +            nn.init.xavier_uniform_(self.input_embedding.weight)  # bad convergence with default init

         def get_initial_state(self) -> TransformerDecoderV1State:
             """:return: initial decoder state"""
    -        return {
    +        state: TransformerDecoderV1State = {
                 "block_state": [block.get_initial_state() for block in self.module_list],
    -            "pos": torch.tensor(0, dtype=torch.int32),
             }
    +
    +        if self.positional_encoding is not None:
    +            state["pos_state"] = self.positional_encoding.get_initial_state()
    +
    +        return state

         def transform_encoder_output(
             self,
             encoder_output: Tensor,
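
The weight tying set up in the constructor above (and questioned in the preceding review comment) can be checked in isolation with plain PyTorch: assigning the embedding weight to a bias-free `nn.Linear` gives the same logits as the old `F.linear(output, self.input_embedding.weight, None)` path that the forward pass used to take. This is a standalone sanity check, not code from the PR:

```python
import torch
from torch import nn
import torch.nn.functional as F

vocab_size, model_dim = 11, 8
input_embedding = nn.Embedding(vocab_size, model_dim)
out_logits = nn.Linear(model_dim, vocab_size, bias=False)
out_logits.weight = input_embedding.weight  # tie the output projection to the embedding matrix

x = torch.randn(2, 5, model_dim)
# The tied Linear and the explicit F.linear over the embedding weight give identical logits.
assert torch.allclose(out_logits(x), F.linear(x, input_embedding.weight, None))
```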

    @@ -228,26 +295,26 @@ def forward(
               - `enc_out, enc_out_mask = forward_some_encoder(...)` and
               - `s = get_initial_state()`.
             """
    +        new_state: TransformerDecoderV1State = {
    +            **state,
    +        }
    +
             x = self.input_embedding(labels) * self.input_embedding_scale
    -        sinus_pe = ConformerMHSARelPosV1._sinusoidal_pe(
    -            torch.arange(labels.shape[-1], device=labels.device) + state["pos"], self.model_dim
    -        )
    -        x = x + sinus_pe.unsqueeze(0)
    +        if self.positional_encoding is not None:
    +            x, new_pos_state = self.positional_encoding(x, labels, labels_lens, state["pos_state"])
    +            new_state["pos_state"] = new_pos_state

             x = self.input_dropout(x)

             output = x
             new_block_states = []
             for block, block_state in zip(self.module_list, state["block_state"]):
                 output, new_block_state = block(output, labels_lens, block_state)
                 new_block_states.append(new_block_state)
    -        new_state: TransformerDecoderV1State = {
    -            **state,
    -            "block_state": new_block_states,
    -            "pos": state["pos"] + labels_lens.max(),
    -        }
    +        new_state["block_state"] = new_block_states

             output = self.out_norm(output)
    -        output_logits = (
    -            F.linear(output, self.input_embedding.weight, None) if self.share_embedding else self.out_logits(output)
    -        )
    -        return output_logits, new_state
    +        output = self.out_logits(output)
    +
    +        return output, new_state

Maybe should be moved to `primitives`?
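
Finally, a minimal standalone sketch of the "optional final matmul" from the PR title: with `output_linear_projection=False` the decoder's `out_logits` becomes `nn.Identity()`, so the forward pass returns `model_dim`-sized features and leaves any projection to the vocabulary to an external module. The names mirror the diff, but the snippet itself is illustrative rather than repo code:

```python
import torch
from torch import nn

model_dim, num_output = 8, 11
x = torch.randn(2, 5, model_dim)  # stands in for the decoder output after out_norm

# output_linear_projection=True: project to num_output logits.
out_logits = nn.Linear(model_dim, num_output, bias=False)
print(out_logits(x).shape)  # torch.Size([2, 5, 11])

# output_linear_projection=False: no final matmul, features pass through unchanged.
out_logits = nn.Identity()
print(out_logits(x).shape)  # torch.Size([2, 5, 8])
```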