diff --git a/i6_models/parts/conformer/mhsa_rel_pos.py b/i6_models/parts/conformer/mhsa_rel_pos.py
index 7d2f8cc7..6f05dc26 100644
--- a/i6_models/parts/conformer/mhsa_rel_pos.py
+++ b/i6_models/parts/conformer/mhsa_rel_pos.py
@@ -11,8 +11,8 @@ import torch.nn.functional as F
 
 from i6_models.config import ModelConfiguration
-from i6_models.util import compat
 from i6_models.parts.dropout import BroadcastDropout
+from i6_models.util import compat
 
 
 @dataclass
@@ -173,10 +173,16 @@ def forward(self, input_tensor: torch.Tensor, sequence_mask: torch.Tensor) -> to
             rel_pos_embeddings = self.rel_pos_embeddings[final_mat]  # [T, T', pos_emb_dim]
         else:
-            rel_pos_embeddings = self._sinusoidal_pe(
-                torch.arange(time_dim_size - 1, -time_dim_size, -1, device=input_tensor.device, dtype=torch.float32),
-                self.pos_emb_dim,
-            ).view(1, 2 * time_dim_size - 1, self.pos_emb_dim)  # [1, T+T'-1, pos_emb_dim]
+            rel_pos_embeddings = (
+                self._sinusoidal_pe(
+                    torch.arange(
+                        time_dim_size - 1, -time_dim_size, -1, device=input_tensor.device, dtype=torch.float32
+                    ),
+                    self.pos_emb_dim,
+                )
+                .to(input_tensor.dtype)
+                .view(1, 2 * time_dim_size - 1, self.pos_emb_dim)
+            )  # [1, T+T'-1, pos_emb_dim]
 
         # dropout relative positional embeddings
         rel_pos_embeddings = self.pos_emb_dropout(
@@ -195,31 +201,43 @@ def forward(self, input_tensor: torch.Tensor, sequence_mask: torch.Tensor) -> to
         q_with_bias_u = q + self.pos_bias_u if self.with_pos_bias else q  # [B, T, #heads, F']
         q_with_bias_v = q + self.pos_bias_v if self.with_pos_bias else q
 
-        # attention matrix a and c
-        attn_ac = torch.einsum("bihf, bjhf -> bhij", q_with_bias_u, k)  # [B, #heads, T, T']
-
         # attention matrix b and d
         attn_bd = torch.einsum(
-            "bihf, ijhf -> bhij", q_with_bias_v, rel_pos_embeddings
+            "bihf, ijhf -> bhij",
+            q_with_bias_v,
+            rel_pos_embeddings.to(device=q_with_bias_v.device, dtype=q_with_bias_v.dtype),
         )  # [B, #heads, T, T'] or [B, #heads, T, T+T'+1]
-
         if not self.learnable_pos_emb:
             attn_bd = self._rel_shift_bhij(attn_bd, k_len=time_dim_size)  # [B, #heads, T, T']
 
-        attn = attn_ac + attn_bd + mask  # [B, #heads, T, T']
-        attn_scaled = attn * (math.sqrt(1.0 / float(self.embed_dim_per_head)))  # [B, #heads, T, T']
+        # We use attn_mask to add the BD matrix to the attention scores.
+        #
+        # Inside torch's SDPA the mask is added after the regular scaling, so to get correct
+        # results, we need to apply the scaling here before passing it to SDPA.
+        #
+        # See for reference:
+        # https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
+        attn_bd_mask = attn_bd + mask
+        scale = math.sqrt(1.0 / float(self.embed_dim_per_head))
+        attn_bd_mask_scaled = attn_bd_mask * scale
 
-        # softmax and dropout
-        attn_output_weights = self.att_weights_dropout(F.softmax(attn_scaled, dim=-1))  # [B, #heads, T, T']
-
-        # sequence of weighted sums over value sequence
         v = value_seq.view(batch_dim_size, -1, self.num_heads, self.embed_dim_per_head)  # [B, T, H, F']
-        attn_output = torch.einsum("bhij, bjhf -> bihf", attn_output_weights, v).reshape(
-            batch_dim_size, -1, self.embed_dim
-        )
 
-        output_tensor = self.out_proj(attn_output)
+        # Use torch's SDPA for efficiency.
+        #
+        # The attention matrices a and c are computed inside torch's SDPA.
+        attn_output = F.scaled_dot_product_attention(
+            q_with_bias_u.transpose(-3, -2),  # [B, #heads, T, F']
+            k.transpose(-3, -2),  # [B, #heads, T', F']
+            v.transpose(-3, -2),  # [B, #heads, T, F']
+            attn_mask=attn_bd_mask_scaled,  # [B, #heads, T, T']
+            dropout_p=self.att_weights_dropout.p if self.training else 0.0,
+            scale=scale,
+        )  # [B, #heads, T, F']
+        attn_output = attn_output.transpose(-3, -2).flatten(-2)  # [B, T, F]
+        assert attn_output.shape[-1] == self.embed_dim
 
+        output_tensor = self.out_proj(attn_output)
         output_tensor = self.dropout(output_tensor)
 
         return output_tensor  # [B,T,F]
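
Review note (not part of the patch): the pre-scaling of the additive bias is the subtle point here, since F.scaled_dot_product_attention applies its scale only to the q @ k^T term and adds attn_mask afterwards. The minimal sketch below illustrates the equivalence with the previous manual computation softmax(scale * (Q K^T + bias)) @ V; the tensor sizes and variable names are made up for illustration, and it assumes a PyTorch version that supports the scale argument of SDPA.

# Standalone sanity check (illustrative only, not part of the patch).
import math

import torch
import torch.nn.functional as F

B, H, T, F_head = 2, 4, 5, 8  # batch, heads, time, per-head feature dim (arbitrary values)
q = torch.randn(B, H, T, F_head)
k = torch.randn(B, H, T, F_head)
v = torch.randn(B, H, T, F_head)
bias = torch.randn(B, H, T, T)  # stands in for attn_bd + mask

scale = math.sqrt(1.0 / float(F_head))

# Old path: the scaling is applied to the sum of the score matrix and the bias.
manual = torch.softmax((q @ k.transpose(-2, -1) + bias) * scale, dim=-1) @ v

# New path: SDPA scales only q @ k^T and adds attn_mask afterwards,
# so the bias must be pre-scaled to stay equivalent.
sdpa = F.scaled_dot_product_attention(q, k, v, attn_mask=bias * scale, scale=scale)

torch.testing.assert_close(manual, sdpa, atol=1e-4, rtol=1e-4)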