58 changes: 38 additions & 20 deletions i6_models/parts/conformer/mhsa_rel_pos.py
@@ -11,8 +11,8 @@
import torch.nn.functional as F

from i6_models.config import ModelConfiguration
from i6_models.util import compat
from i6_models.parts.dropout import BroadcastDropout
from i6_models.util import compat


@dataclass
@@ -173,10 +173,16 @@ def forward(self, input_tensor: torch.Tensor, sequence_mask: torch.Tensor) -> to

rel_pos_embeddings = self.rel_pos_embeddings[final_mat] # [T, T', pos_emb_dim]
else:
rel_pos_embeddings = self._sinusoidal_pe(
torch.arange(time_dim_size - 1, -time_dim_size, -1, device=input_tensor.device, dtype=torch.float32),
self.pos_emb_dim,
).view(1, 2 * time_dim_size - 1, self.pos_emb_dim) # [1, T+T'-1, pos_emb_dim]
rel_pos_embeddings = (
self._sinusoidal_pe(
torch.arange(
time_dim_size - 1, -time_dim_size, -1, device=input_tensor.device, dtype=torch.float32
),
self.pos_emb_dim,
)
.to(input_tensor.dtype)
.view(1, 2 * time_dim_size - 1, self.pos_emb_dim)
) # [1, T+T'-1, pos_emb_dim]

# dropout relative positional embeddings
rel_pos_embeddings = self.pos_emb_dropout(
@@ -195,31 +201,43 @@ def forward(self, input_tensor: torch.Tensor, sequence_mask: torch.Tensor) -> to
q_with_bias_u = q + self.pos_bias_u if self.with_pos_bias else q # [B, T, #heads, F']
q_with_bias_v = q + self.pos_bias_v if self.with_pos_bias else q

# attention matrix a and c
attn_ac = torch.einsum("bihf, bjhf -> bhij", q_with_bias_u, k) # [B, #heads, T, T']

# attention matrix b and d
attn_bd = torch.einsum(
"bihf, ijhf -> bhij", q_with_bias_v, rel_pos_embeddings
"bihf, ijhf -> bhij",
q_with_bias_v,
rel_pos_embeddings.to(device=q_with_bias_v.device, dtype=q_with_bias_v.dtype),
) # [B, #heads, T, T'] or [B, #heads, T, T+T'-1]

if not self.learnable_pos_emb:
attn_bd = self._rel_shift_bhij(attn_bd, k_len=time_dim_size) # [B, #heads, T, T']

attn = attn_ac + attn_bd + mask # [B, #heads, T, T']
attn_scaled = attn * (math.sqrt(1.0 / float(self.embed_dim_per_head))) # [B, #heads, T, T']
# We use attn_mask to add the BD matrix to the attention scores.
#
# Inside torch's SDPA the attn_mask is added after the query-key scores have
# already been scaled, so to get correct results we apply the scaling to this
# term here before passing it to SDPA.
#
# See for reference:
# https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
attn_bd_mask = attn_bd + mask
scale = math.sqrt(1.0 / float(self.embed_dim_per_head))
attn_bd_mask_scaled = attn_bd_mask * scale

# softmax and dropout
attn_output_weights = self.att_weights_dropout(F.softmax(attn_scaled, dim=-1)) # [B, #heads, T, T']

# sequence of weighted sums over value sequence
v = value_seq.view(batch_dim_size, -1, self.num_heads, self.embed_dim_per_head) # [B, T, H, F']
attn_output = torch.einsum("bhij, bjhf -> bihf", attn_output_weights, v).reshape(
batch_dim_size, -1, self.embed_dim
)

output_tensor = self.out_proj(attn_output)
# Use torch's SDPA for efficiency.
#
# The attention matrices a and c are computed inside torch's SDPA.
attn_output = F.scaled_dot_product_attention(
q_with_bias_u.transpose(-3, -2), # [B, #heads, T, F']
k.transpose(-3, -2), # [B, #heads, T', F']
v.transpose(-3, -2), # [B, #heads, T', F']
attn_mask=attn_bd_mask_scaled, # [B, #heads, T, T']
dropout_p=self.att_weights_dropout.p if self.training else 0.0,
scale=scale,
) # [B, #heads, T, F']
attn_output = attn_output.transpose(-3, -2).flatten(-2) # [B, T, F]
assert attn_output.shape[-1] == self.embed_dim

output_tensor = self.out_proj(attn_output)
output_tensor = self.dropout(output_tensor)

return output_tensor # [B,T,F]
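
A minimal standalone sketch (not part of the PR; shapes and tensor names are illustrative) of why the BD/mask term is pre-scaled: torch's SDPA computes softmax(QK^T * scale + attn_mask), i.e. the mask is added after the scores are scaled, so passing (attn_bd + mask) * scale as attn_mask reproduces softmax((AC + BD + mask) * scale). Assumes PyTorch >= 2.1 for the `scale` keyword of `scaled_dot_product_attention`.

```python
import math
import torch
import torch.nn.functional as F

B, H, T, Fh = 2, 4, 5, 8  # batch, heads, time, feature dim per head (illustrative)
q = torch.randn(B, H, T, Fh)
k = torch.randn(B, H, T, Fh)
v = torch.randn(B, H, T, Fh)
attn_bd = torch.randn(B, H, T, T)   # stand-in for the relative-position term (matrices b and d)
mask = torch.zeros(B, 1, 1, T)      # additive sequence mask: 0 = keep, -inf = drop
scale = math.sqrt(1.0 / float(Fh))

# reference: explicit attention with all terms scaled together
attn_ac = torch.einsum("bhif, bhjf -> bhij", q, k)
ref = torch.softmax((attn_ac + attn_bd + mask) * scale, dim=-1) @ v

# SDPA scales q @ k^T first and then adds attn_mask unscaled,
# hence the bias/mask term is scaled before being passed in
out = F.scaled_dot_product_attention(
    q, k, v, attn_mask=(attn_bd + mask) * scale, dropout_p=0.0, scale=scale
)
torch.testing.assert_close(out, ref)
```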