@ymcki Hi, simply upgrade your FLA; this warning has been removed.
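If upgrading FLA is not immediately possible, the message can be silenced with Python's standard `warnings` module. This is a minimal sketch: `suppress_fla_format_warning` is a hypothetical helper name, and the regex simply matches the start of the warning text quoted in the question. It only hides the message and does not change what the model computes.

```python
import warnings


def suppress_fla_format_warning() -> None:
    """Hypothetical helper: ignore only the FLA shape-heuristic UserWarning.

    The pattern is matched against the start of the warning message, so
    other UserWarnings are still shown.
    """
    warnings.filterwarnings(
        "ignore",
        message=r"Input tensor shape suggests potential format mismatch",
        category=UserWarning,
    )


# Call once, before running generation.
suppress_fla_format_warning()
```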
I am trying to replicate the results of Jet-Nemotron-2B, which depends on the fla package. When I run the sample code from the Hugging Face page, I get this warning, and I am not seeing the performance claimed for this model:
/home/user/anaconda3/lib/python3.12/site-packages/fla/ops/gated_delta_rule/fused_recurrent.py:292: UserWarning: Input tensor shape suggests potential format mismatch: seq_len (1) < num_heads (12). This may indicate the inputs were passed in head-first format [B, H, T, ...] when head_first=False was specified. Please verify your input tensor format matches the expected shape [B, T, H, ...].
How should I format my prompt so that this warning goes away? Do I need to pass at least 12 prompts in a batch to avoid it? If not, what should I do?
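Going only by the wording of the warning, the check appears to compare the second dimension (seq_len, T) with the third (num_heads, H) of a `[B, T, H, ...]` tensor. The sketch below re-creates that comparison under this assumption; `looks_head_first` is a hypothetical name, not FLA's actual code, and the head dimension of 128 is an arbitrary placeholder. It shows that the batch size B plays no part: if generation feeds one token per step, as the `seq_len (1)` in the message suggests, the comparison is true on every decoding step however many prompts are batched.

```python
def looks_head_first(shape, head_first=False):
    """Hypothetical re-creation of the heuristic described in the warning.

    `shape` is assumed to be (B, T, H, D) when head_first=False. Returns True
    when the warning condition seq_len < num_heads would hold.
    """
    if head_first:
        # The warning text only concerns the head_first=False case.
        return False
    _batch, seq_len, num_heads, _head_dim = shape
    return seq_len < num_heads


# Prefill of a 20-token prompt with 12 heads: T >= H, condition not met.
print(looks_head_first((1, 20, 12, 128)))
# One decoding step (T = 1) with a single prompt: condition met.
print(looks_head_first((1, 1, 12, 128)))
# One decoding step with 16 prompts batched: still met, since B is ignored.
print(looks_head_first((16, 1, 12, 128)))
```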