
Commit 9577fb4

raise error on FlashAttnKwargs + grad ckpt
1 parent 9f59f93 commit 9577fb4

File tree

2 files changed: +10 -0 lines changed

src/transformers/models/bamba/modeling_bamba.py

Lines changed: 5 additions & 0 deletions
@@ -1302,6 +1302,11 @@ def forward(
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
+               if "cu_seq_lens_q" in flash_attn_kwargs:
+                   raise NotImplementedError(
+                       "Padding-free training with FlashAttentionKwargs and gradient checkpointing"
+                       " is not currently supported."
+                   )
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,

src/transformers/models/bamba/modular_bamba.py

Lines changed: 5 additions & 0 deletions
@@ -1050,6 +1050,11 @@ def forward(
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
+               if "cu_seq_lens_q" in flash_attn_kwargs:
+                   raise NotImplementedError(
+                       "Padding-free training with FlashAttentionKwargs and gradient checkpointing"
+                       " is not currently supported."
+                   )
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
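For readers skimming the diff: the added guard refuses to combine padding-free FlashAttention inputs (signalled by a cu_seq_lens_q entry in the per-call kwargs) with gradient checkpointing during training. Below is a minimal standalone sketch of the same check; the helper name and the plain-dict flash_attn_kwargs argument are illustrative only, not part of the transformers API.

# Illustrative sketch only: mirrors the guard added in the diff above,
# outside of BambaModel.forward. The helper name is hypothetical.
def _reject_padding_free_with_grad_ckpt(flash_attn_kwargs, gradient_checkpointing, training):
    # cu_seq_lens_q is only present when a padding-free (flattened) batch is
    # being routed to FlashAttention; that path is assumed not to survive the
    # gradient-checkpointing call, so the combination is rejected up front.
    if gradient_checkpointing and training and "cu_seq_lens_q" in flash_attn_kwargs:
        raise NotImplementedError(
            "Padding-free training with FlashAttentionKwargs and gradient checkpointing"
            " is not currently supported."
        )

# Example: this call would raise, because both features are requested at once.
# _reject_padding_free_with_grad_ckpt({"cu_seq_lens_q": [0, 7, 19]}, gradient_checkpointing=True, training=True)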
