111 changes: 89 additions & 22 deletions i6_models/assemblies/transformer/transformer_decoder_v1.py
@@ -17,12 +17,11 @@

import torch
from torch import nn, Tensor
import torch.nn.functional as F

from dataclasses import dataclass, field
from typing import List, Optional, Tuple, TypedDict, Union
from typing import List, Optional, Tuple, TypedDict, Union, NotRequired

from i6_models.config import ModelConfiguration
from i6_models.config import ModelConfiguration, ModuleFactoryV1
from i6_models.parts.conformer import (
ConformerMHSARelPosV1,
ConformerPositionwiseFeedForwardV2,
@@ -128,6 +127,56 @@ def forward(
return labels, {**state, "module_states": new_states}


@dataclass
class SinusoidalPositionalEncodingV1Config(ModelConfiguration):
"""
Attributes:
embedding_dim: embedding dimension
"""

embedding_dim: int


class PositionalEncodingV1State(TypedDict):
"""
State for some positional encoding.
"""

pos: Tensor


class SinusoidalPositionalEncodingV1(nn.Module, ModuleWithState[PositionalEncodingV1State]):
"""
Computes and applies a sinusoidal positional encoding.
"""

def __init__(self, cfg: SinusoidalPositionalEncodingV1Config):
super().__init__()

self.embed_dim = cfg.embedding_dim

def forward(self, inputs: Tensor, labels: Tensor, lengths: Tensor, state: PositionalEncodingV1State):
"""
Apply sinusoidal positional encoding on the inputs.

:param inputs: input embeddings
:param labels: input labels
:param lengths: input lengths
:param state: current state of positional encoding.
"""
        sinus_pe = ConformerMHSARelPosV1._sinusoidal_pe(
            torch.arange(labels.shape[-1], device=labels.device) + state["pos"], self.embed_dim
        )

Contributor Author (on the _sinusoidal_pe call): Maybe this should be moved to primitives?

Contributor Author (on the positions argument): Do we have the constraint labels.shape[-1] == lengths.max()? If so, the labels input can be removed. Or do we only return sinus_pe.unsqueeze(0) and apply the addition later?

output = inputs + sinus_pe.unsqueeze(0)

new_state: PositionalEncodingV1State = {"pos": state["pos"] + lengths.max()}

return output, new_state

def get_initial_state(self) -> PositionalEncodingV1State:
return {"pos": torch.tensor(0, dtype=torch.int32)}
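
For illustration, a minimal usage sketch of the stateful interface above; the shapes and embedding size are arbitrary examples, and it assumes the classes in this file are importable:

import torch

cfg = SinusoidalPositionalEncodingV1Config(embedding_dim=8)
pos_enc = SinusoidalPositionalEncodingV1(cfg)
state = pos_enc.get_initial_state()                    # {"pos": tensor(0)}

labels = torch.zeros((2, 3), dtype=torch.int64)        # dummy label ids, [batch, time]
inputs = torch.randn(2, 3, 8)                          # embedded labels, [batch, time, embedding_dim]
lengths = torch.full((2,), 3, dtype=torch.int32)

out, state = pos_enc(inputs, labels, lengths, state)   # encodes positions 0..2
out, state = pos_enc(inputs, labels, lengths, state)   # continues at position 3 via state["pos"]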
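
Regarding the inline question above (dropping the labels argument and returning only the encoding), a hypothetical variant of forward could look like this; the signature and names are illustrative, not part of the diff:

    def forward(self, lengths: Tensor, state: PositionalEncodingV1State) -> Tuple[Tensor, PositionalEncodingV1State]:
        """Return only the positional encoding; the caller applies `x = x + pe`."""
        num_frames = int(lengths.max())
        sinus_pe = ConformerMHSARelPosV1._sinusoidal_pe(
            torch.arange(num_frames, device=lengths.device) + state["pos"], self.embed_dim
        )
        new_state: PositionalEncodingV1State = {"pos": state["pos"] + num_frames}
        return sinus_pe.unsqueeze(0), new_state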


@dataclass
class TransformerDecoderV1Config(ModelConfiguration):
"""
@@ -141,6 +190,8 @@ class TransformerDecoderV1Config(ModelConfiguration):
logits_bias: Whether to add a bias to the output logits.
Usually False is a good choice.
share_embedding: Whether to share the input and output embedding.
positional_encoding: Optional factory for a positional encoding module applied to the input embeddings.
output_linear_projection: Whether to apply a linear projection of the model output to 'num_output' dimensions.
"""

block_cfg: TransformerDecoderBlockV1Config
@@ -150,13 +201,15 @@ class TransformerDecoderV1Config(ModelConfiguration):
num_output: int
logits_bias: bool
share_embedding: bool
positional_encoding: Optional[ModuleFactoryV1]
output_linear_projection: bool = True
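
As an illustration of the new field, a positional encoding could be plugged in roughly like this. This is a sketch, assuming ModuleFactoryV1 pairs a module class with its config and instantiates it when called, as cfg.positional_encoding() in __init__ below suggests; the embedding size is an arbitrary example:

pos_enc_factory = ModuleFactoryV1(
    SinusoidalPositionalEncodingV1,
    SinusoidalPositionalEncodingV1Config(embedding_dim=512),
)
# passed as TransformerDecoderV1Config(..., positional_encoding=pos_enc_factory, output_linear_projection=True)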


class TransformerDecoderV1State(TypedDict):
"""Recurrent state of the transformer decoder."""

block_state: List[TransformerDecoderBlockV1State]
pos: Tensor
pos_state: NotRequired[PositionalEncodingV1State]
Contributor Author: Okay, maybe we should not change the names and types here, as this breaks existing setups.



class TransformerDecoderV1(nn.Module, ModuleWithState[TransformerDecoderV1State]):
@@ -184,19 +237,33 @@ def __init__(self, cfg: TransformerDecoderV1Config):
)
self.out_norm = nn.LayerNorm(self.model_dim)
self.share_embedding = cfg.share_embedding
if cfg.share_embedding:
assert not cfg.logits_bias, "Cannot use logits bias with shared embedding"
nn.init.xavier_uniform_(self.input_embedding.weight) # bad convergence with default init

self.positional_encoding = None
if cfg.positional_encoding is not None:
self.positional_encoding = cfg.positional_encoding()

self.output_linear_projection = cfg.output_linear_projection

if not self.output_linear_projection:
self.out_logits = nn.Identity()
else:
self.out_logits = nn.Linear(self.model_dim, cfg.num_output, bias=cfg.logits_bias)
Member: I just realized this sharing is weird. I would always set self.out_logits. If sharing, you can just do self.out_logits.weight = self.input_embedding.weight. That would simplify the other code.

Also, self.out_logits should always be set (or be None if not used). But with my suggestion, you don't need to care about this.

And then you would also allow logits_bias=True together with share_embedding=True.
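
A minimal sketch of that suggestion (hypothetical, not part of the diff):

        # Always create the output projection; tie weights when sharing, keep the bias usable.
        self.out_logits = nn.Linear(self.model_dim, cfg.num_output, bias=cfg.logits_bias)
        if cfg.share_embedding:
            self.out_logits.weight = self.input_embedding.weight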


if cfg.share_embedding and self.output_linear_projection:
self.out_logits.weight = self.input_embedding.weight
nn.init.xavier_uniform_(self.input_embedding.weight) # bad convergence with default init

def get_initial_state(self) -> TransformerDecoderV1State:
""":return: initial decoder state"""
return {
state: TransformerDecoderV1State = {
"block_state": [block.get_initial_state() for block in self.module_list],
"pos": torch.tensor(0, dtype=torch.int32),
}

if self.positional_encoding is not None:
state["pos_state"] = self.positional_encoding.get_initial_state()

return state

def transform_encoder_output(
self,
encoder_output: Tensor,
@@ -228,26 +295,26 @@ def forward(
- `enc_out, enc_out_mask = forward_some_encoder(...)` and
- `s = get_initial_state()`.
"""
new_state: TransformerDecoderV1State = {
**state,
}

x = self.input_embedding(labels) * self.input_embedding_scale
sinus_pe = ConformerMHSARelPosV1._sinusoidal_pe(
torch.arange(labels.shape[-1], device=labels.device) + state["pos"], self.model_dim
)
x = x + sinus_pe.unsqueeze(0)

if self.positional_encoding is not None:
x, new_pos_state = self.positional_encoding(x, labels, labels_lens, state["pos_state"])
new_state["pos_state"] = new_pos_state

x = self.input_dropout(x)

output = x
new_block_states = []
for block, block_state in zip(self.module_list, state["block_state"]):
output, new_block_state = block(output, labels_lens, block_state)
new_block_states.append(new_block_state)
new_state: TransformerDecoderV1State = {
**state,
"block_state": new_block_states,
"pos": state["pos"] + labels_lens.max(),
}
new_state["block_state"] = new_block_states

output = self.out_norm(output)
output_logits = (
F.linear(output, self.input_embedding.weight, None) if self.share_embedding else self.out_logits(output)
)
return output_logits, new_state
output = self.out_logits(output)

return output, new_state