2 files changed: +10 -0

src/transformers/models/bamba

@@ -1302,6 +1302,11 @@ def forward(
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
+                if "cu_seq_lens_q" in flash_attn_kwargs:
+                    raise NotImplementedError(
+                        "Padding-free training with FlashAttentionKwargs and gradient checkpointing"
+                        " not currently supported."
+                    )
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
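Both hunks add the same guard: a plain membership test on the FlashAttention kwargs dict that padding-free batches carry. A minimal standalone sketch of the same logic follows; the helper name and the example call are illustrative only and not part of the diff.

    from typing import Any


    def guard_padding_free_gradient_checkpointing(
        gradient_checkpointing: bool, training: bool, flash_attn_kwargs: dict
    ) -> None:
        # Padding-free batches are identified by the presence of cu_seq_lens_q
        # among the FlashAttention kwargs; combined with gradient checkpointing
        # this path is rejected, mirroring the guard added in the diff.
        if gradient_checkpointing and training and "cu_seq_lens_q" in flash_attn_kwargs:
            raise NotImplementedError(
                "Padding-free training with FlashAttentionKwargs and gradient checkpointing"
                " not currently supported."
            )


    # A regular padded batch (no cu_seq_lens_q) passes through;
    # a padding-free batch under gradient checkpointing would raise.
    guard_padding_free_gradient_checkpointing(True, True, {})  # no error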
@@ -1050,6 +1050,11 @@ def forward(
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
+                if "cu_seq_lens_q" in flash_attn_kwargs:
+                    raise NotImplementedError(
+                        "Padding-free training with FlashAttentionKwargs and gradient checkpointing"
+                        " not currently supported."
+                    )
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
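The cu_seq_lens_q key the guard looks for is the cumulative-sequence-length tensor that flash-attention's variable-length kernels use to locate sequence boundaries in a flattened, padding-free batch. A small sketch of how such kwargs are typically shaped; the lengths and the exact key set below are assumptions for illustration, not taken from the PR.

    import torch

    # Hypothetical lengths of three samples packed into one padding-free batch.
    seq_lens = torch.tensor([5, 3, 7], dtype=torch.int32)

    # Cumulative boundaries [0, 5, 8, 15]: token ranges [0:5), [5:8), [8:15).
    cu_seq_lens = torch.nn.functional.pad(seq_lens.cumsum(0, dtype=torch.int32), (1, 0))

    flash_attn_kwargs = {
        "cu_seq_lens_q": cu_seq_lens,
        "cu_seq_lens_k": cu_seq_lens,
        "max_length_q": int(seq_lens.max()),
        "max_length_k": int(seq_lens.max()),
    }

    # With kwargs like these present, the forward passes above raise
    # NotImplementedError whenever gradient checkpointing is enabled in training.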