From 58699f60ef8fe866714cdf22f92d4f843e6de211 Mon Sep 17 00:00:00 2001 From: Moritz Gunz Date: Wed, 3 Dec 2025 14:17:13 +0100 Subject: [PATCH 1/4] Use torch's fused SDPA for attention computation --- i6_models/parts/conformer/mhsa_rel_pos.py | 40 ++++++++++++++--------- 1 file changed, 25 insertions(+), 15 deletions(-) diff --git a/i6_models/parts/conformer/mhsa_rel_pos.py b/i6_models/parts/conformer/mhsa_rel_pos.py index 7d2f8cc7..689c8379 100644 --- a/i6_models/parts/conformer/mhsa_rel_pos.py +++ b/i6_models/parts/conformer/mhsa_rel_pos.py @@ -11,8 +11,8 @@ import torch.nn.functional as F from i6_models.config import ModelConfiguration -from i6_models.util import compat from i6_models.parts.dropout import BroadcastDropout +from i6_models.util import compat @dataclass @@ -195,31 +195,41 @@ def forward(self, input_tensor: torch.Tensor, sequence_mask: torch.Tensor) -> to q_with_bias_u = q + self.pos_bias_u if self.with_pos_bias else q # [B, T, #heads, F'] q_with_bias_v = q + self.pos_bias_v if self.with_pos_bias else q - # attention matrix a and c - attn_ac = torch.einsum("bihf, bjhf -> bhij", q_with_bias_u, k) # [B, #heads, T, T'] - # attention matrix b and d attn_bd = torch.einsum( "bihf, ijhf -> bhij", q_with_bias_v, rel_pos_embeddings ) # [B, #heads, T, T'] or [B, #heads, T, T+T'+1] - if not self.learnable_pos_emb: attn_bd = self._rel_shift_bhij(attn_bd, k_len=time_dim_size) # [B, #heads, T, T'] - attn = attn_ac + attn_bd + mask # [B, #heads, T, T'] - attn_scaled = attn * (math.sqrt(1.0 / float(self.embed_dim_per_head))) # [B, #heads, T, T'] + # We use attn_mask to add BD matrix to attention scores. + # + # Inside torch's SDPA the mask is added after regular scaling, so to get correct + # results, we need to apply the scaling here before passing to SDPA. + # + # See for reference: + # https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html + attn_bd_mask = attn_bd + mask + scale = math.sqrt(1.0 / float(self.embed_dim_per_head)) + attn_bd_mask_scaled = attn_bd_mask * scale - # softmax and dropout - attn_output_weights = self.att_weights_dropout(F.softmax(attn_scaled, dim=-1)) # [B, #heads, T, T'] - - # sequence of weighted sums over value sequence v = value_seq.view(batch_dim_size, -1, self.num_heads, self.embed_dim_per_head) # [B, T, H, F'] - attn_output = torch.einsum("bhij, bjhf -> bihf", attn_output_weights, v).reshape( - batch_dim_size, -1, self.embed_dim - ) - output_tensor = self.out_proj(attn_output) + # Use torch's SDPA for efficiency. + # + # The attention matrices a and c are computed inside torch's sdpa. 
+ attn_output = F.scaled_dot_product_attention( + q_with_bias_u.transpose(-3, -2), # [B, #heads, T, F'] + k.transpose(-3, -2), # [B, #heads, T', F'] + v.transpose(-3, -2), # [B, #heads, T, F'] + attn_mask=attn_bd_mask_scaled, # [B, #heads, T, T'] + dropout_p=self.att_weights_dropout.p if self.training else 0.0, + scale=scale, + ) # [B, #heads, T, F'] + attn_output = attn_output.transpose(-3, -2).flatten(-2) # [B, T, F'] + assert attn_output.shape[-1] == self.embed_dim + output_tensor = self.out_proj(attn_output) output_tensor = self.dropout(output_tensor) return output_tensor # [B,T,F] From 91fff5aee0fa30ad51ea90cb8832fb2fac9e9b92 Mon Sep 17 00:00:00 2001 From: Moritz Gunz Date: Mon, 22 Dec 2025 11:44:40 +0100 Subject: [PATCH 2/4] REVERTME: add test to evaluate backend support --- tests/test_conformer_rel_pos_attn_backend.py | 95 ++++++++++++++++++++ 1 file changed, 95 insertions(+) create mode 100644 tests/test_conformer_rel_pos_attn_backend.py diff --git a/tests/test_conformer_rel_pos_attn_backend.py b/tests/test_conformer_rel_pos_attn_backend.py new file mode 100644 index 00000000..41cf1b5e --- /dev/null +++ b/tests/test_conformer_rel_pos_attn_backend.py @@ -0,0 +1,95 @@ +from __future__ import annotations +from itertools import product + +import torch +from torch.nn.attention import SDPBackend, sdpa_kernel +import pytest + +from i6_models.parts.conformer.mhsa_rel_pos import ConformerMHSARelPosV1Config, ConformerMHSARelPosV1 + + +def get_model( + input_dim, + with_bias=True, + num_att_heads=8, + att_weights_dropout=0.1, + dropout=0.1, + learnable_pos_emb=True, + with_linear_pos=False, + separate_pos_emb_per_head=False, + rel_pos_clip=16, + with_pos_bias=False, + pos_emb_dropout=0.0, + dropout_broadcast_axes=None, +): + cfg = ConformerMHSARelPosV1Config( + input_dim=input_dim, + num_att_heads=num_att_heads, + with_bias=with_bias, + att_weights_dropout=att_weights_dropout, + dropout=dropout, + learnable_pos_emb=learnable_pos_emb, + with_linear_pos=with_linear_pos, + separate_pos_emb_per_head=separate_pos_emb_per_head, + rel_pos_clip=rel_pos_clip, + with_pos_bias=with_pos_bias, + pos_emb_dropout=pos_emb_dropout, + dropout_broadcast_axes=dropout_broadcast_axes, + ) + return ConformerMHSARelPosV1(cfg) + + +testdata = list( + product( + [True, False], + [True, False], + [0.0, 0.1], + [True, False], + [True, False], + [ + SDPBackend.MATH, + SDPBackend.FLASH_ATTENTION, + SDPBackend.EFFICIENT_ATTENTION, + SDPBackend.CUDNN_ATTENTION, + ], + ) +) + + +def backend_to_str(backend): + for backend_, name in [ + (SDPBackend.MATH, "MATH"), + (SDPBackend.FLASH_ATTENTION, "FLASH_ATTENTION"), + (SDPBackend.EFFICIENT_ATTENTION, "EFFICIENT_ATTENTION"), + (SDPBackend.CUDNN_ATTENTION, "CUDNN_ATTENTION"), + ]: + if backend == backend_: + return name + + return None + + +@pytest.mark.parametrize( + "learnable_pos_emb, with_pos_bias, pos_emb_dropout, with_linear_pos, separate_pos_emb_per_head, backend", + testdata, + ids=backend_to_str, +) +def test_fused_attn_backend( + learnable_pos_emb, with_pos_bias, pos_emb_dropout, with_linear_pos, separate_pos_emb_per_head, backend +): + input_shape = [4, 15, 32] # B,T,F + seq_len = [15, 12, 10, 15] + + input_tensor = torch.randn(input_shape) + sequence_mask = torch.less(torch.arange(input_shape[1])[None, :], torch.tensor(seq_len)[:, None]) + + model = get_model( + input_dim=32, + learnable_pos_emb=learnable_pos_emb, + with_pos_bias=with_pos_bias, + pos_emb_dropout=pos_emb_dropout, + with_linear_pos=with_linear_pos, + separate_pos_emb_per_head=separate_pos_emb_per_head, + 
) + with sdpa_kernel(backend): + _outputs = model(input_tensor, sequence_mask) From b4e6ffd7a23a79aa6d211a4e65ff0b2e5c68ebf2 Mon Sep 17 00:00:00 2001 From: Moritz Gunz Date: Mon, 22 Dec 2025 06:19:15 -0500 Subject: [PATCH 3/4] Add proper bfloat16 support --- i6_models/parts/conformer/mhsa_rel_pos.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/i6_models/parts/conformer/mhsa_rel_pos.py b/i6_models/parts/conformer/mhsa_rel_pos.py index 689c8379..6f05dc26 100644 --- a/i6_models/parts/conformer/mhsa_rel_pos.py +++ b/i6_models/parts/conformer/mhsa_rel_pos.py @@ -173,10 +173,16 @@ def forward(self, input_tensor: torch.Tensor, sequence_mask: torch.Tensor) -> to rel_pos_embeddings = self.rel_pos_embeddings[final_mat] # [T, T', pos_emb_dim] else: - rel_pos_embeddings = self._sinusoidal_pe( - torch.arange(time_dim_size - 1, -time_dim_size, -1, device=input_tensor.device, dtype=torch.float32), - self.pos_emb_dim, - ).view(1, 2 * time_dim_size - 1, self.pos_emb_dim) # [1, T+T'-1, pos_emb_dim] + rel_pos_embeddings = ( + self._sinusoidal_pe( + torch.arange( + time_dim_size - 1, -time_dim_size, -1, device=input_tensor.device, dtype=torch.float32 + ), + self.pos_emb_dim, + ) + .to(input_tensor.dtype) + .view(1, 2 * time_dim_size - 1, self.pos_emb_dim) + ) # [1, T+T'-1, pos_emb_dim] # dropout relative positional embeddings rel_pos_embeddings = self.pos_emb_dropout( @@ -197,7 +203,9 @@ def forward(self, input_tensor: torch.Tensor, sequence_mask: torch.Tensor) -> to # attention matrix b and d attn_bd = torch.einsum( - "bihf, ijhf -> bhij", q_with_bias_v, rel_pos_embeddings + "bihf, ijhf -> bhij", + q_with_bias_v, + rel_pos_embeddings.to(device=q_with_bias_v.device, dtype=q_with_bias_v.dtype), ) # [B, #heads, T, T'] or [B, #heads, T, T+T'+1] if not self.learnable_pos_emb: attn_bd = self._rel_shift_bhij(attn_bd, k_len=time_dim_size) # [B, #heads, T, T'] From c2783ec291b12cfa496ed1301615a124c4c19570 Mon Sep 17 00:00:00 2001 From: Moritz Gunz Date: Mon, 22 Dec 2025 06:19:37 -0500 Subject: [PATCH 4/4] Revert "REVERTME: add test to evaluate backend support" This reverts commit 91fff5aee0fa30ad51ea90cb8832fb2fac9e9b92. 
--- tests/test_conformer_rel_pos_attn_backend.py | 95 -------------------- 1 file changed, 95 deletions(-) delete mode 100644 tests/test_conformer_rel_pos_attn_backend.py diff --git a/tests/test_conformer_rel_pos_attn_backend.py b/tests/test_conformer_rel_pos_attn_backend.py deleted file mode 100644 index 41cf1b5e..00000000 --- a/tests/test_conformer_rel_pos_attn_backend.py +++ /dev/null @@ -1,95 +0,0 @@ -from __future__ import annotations -from itertools import product - -import torch -from torch.nn.attention import SDPBackend, sdpa_kernel -import pytest - -from i6_models.parts.conformer.mhsa_rel_pos import ConformerMHSARelPosV1Config, ConformerMHSARelPosV1 - - -def get_model( - input_dim, - with_bias=True, - num_att_heads=8, - att_weights_dropout=0.1, - dropout=0.1, - learnable_pos_emb=True, - with_linear_pos=False, - separate_pos_emb_per_head=False, - rel_pos_clip=16, - with_pos_bias=False, - pos_emb_dropout=0.0, - dropout_broadcast_axes=None, -): - cfg = ConformerMHSARelPosV1Config( - input_dim=input_dim, - num_att_heads=num_att_heads, - with_bias=with_bias, - att_weights_dropout=att_weights_dropout, - dropout=dropout, - learnable_pos_emb=learnable_pos_emb, - with_linear_pos=with_linear_pos, - separate_pos_emb_per_head=separate_pos_emb_per_head, - rel_pos_clip=rel_pos_clip, - with_pos_bias=with_pos_bias, - pos_emb_dropout=pos_emb_dropout, - dropout_broadcast_axes=dropout_broadcast_axes, - ) - return ConformerMHSARelPosV1(cfg) - - -testdata = list( - product( - [True, False], - [True, False], - [0.0, 0.1], - [True, False], - [True, False], - [ - SDPBackend.MATH, - SDPBackend.FLASH_ATTENTION, - SDPBackend.EFFICIENT_ATTENTION, - SDPBackend.CUDNN_ATTENTION, - ], - ) -) - - -def backend_to_str(backend): - for backend_, name in [ - (SDPBackend.MATH, "MATH"), - (SDPBackend.FLASH_ATTENTION, "FLASH_ATTENTION"), - (SDPBackend.EFFICIENT_ATTENTION, "EFFICIENT_ATTENTION"), - (SDPBackend.CUDNN_ATTENTION, "CUDNN_ATTENTION"), - ]: - if backend == backend_: - return name - - return None - - -@pytest.mark.parametrize( - "learnable_pos_emb, with_pos_bias, pos_emb_dropout, with_linear_pos, separate_pos_emb_per_head, backend", - testdata, - ids=backend_to_str, -) -def test_fused_attn_backend( - learnable_pos_emb, with_pos_bias, pos_emb_dropout, with_linear_pos, separate_pos_emb_per_head, backend -): - input_shape = [4, 15, 32] # B,T,F - seq_len = [15, 12, 10, 15] - - input_tensor = torch.randn(input_shape) - sequence_mask = torch.less(torch.arange(input_shape[1])[None, :], torch.tensor(seq_len)[:, None]) - - model = get_model( - input_dim=32, - learnable_pos_emb=learnable_pos_emb, - with_pos_bias=with_pos_bias, - pos_emb_dropout=pos_emb_dropout, - with_linear_pos=with_linear_pos, - separate_pos_emb_per_head=separate_pos_emb_per_head, - ) - with sdpa_kernel(backend): - _outputs = model(input_tensor, sequence_mask)
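
For reference, a minimal standalone sketch of the equivalence the first patch relies on: torch's scaled_dot_product_attention applies `scale` to q @ k^T before adding attn_mask, so the relative-position scores (the b/d matrices) and the sequence mask have to be pre-scaled when they are folded into attn_mask. This is not part of the patches above; it uses plain random tensors in place of the module's projected q/k/v and a random stand-in for the b/d term, so all shapes and names here are illustrative and only the torch calls match the patch.

import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)

B, H, T, F_h = 2, 4, 6, 16  # batch, heads, time, feature dim per head
q = torch.randn(B, H, T, F_h)
k = torch.randn(B, H, T, F_h)
v = torch.randn(B, H, T, F_h)
bd = torch.randn(B, H, T, T)    # stand-in for the relative-position scores (matrices b and d)
mask = torch.zeros(B, 1, 1, T)  # additive sequence mask, -inf at padded key positions
mask[:, :, :, -2:] = float("-inf")

scale = math.sqrt(1.0 / float(F_h))

# Reference: unfused attention as computed before the patch,
# with the scaling applied after all additive terms.
scores = torch.einsum("bhif, bhjf -> bhij", q, k) + bd + mask
reference = torch.einsum("bhij, bhjf -> bhif", F.softmax(scores * scale, dim=-1), v)

# Fused: SDPA adds attn_mask only after scaling q @ k^T,
# so the bias and mask are pre-scaled to keep the result identical.
fused = F.scaled_dot_product_attention(
    q, k, v,
    attn_mask=(bd + mask) * scale,
    dropout_p=0.0,
    scale=scale,
)

torch.testing.assert_close(reference, fused, rtol=1e-5, atol=1e-5)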