Different sequence numbers calculate inconsistent results #696

Open · sitabulaixizawaluduo opened this issue Dec 24, 2024 · 10 comments

@sitabulaixizawaluduo
version: 0.1.6
question: When I compute attention for prompt1 alone versus prompt1 batched together with prompt2, the attention output values I get are different. Is this normal? What is the specific reason? Thank you.

@yzh119 (Collaborator) commented Dec 24, 2024

How large is the discrepancy?

If there is only a tiny difference, I think it's normal: internally we use dynamic split-k to maximize hardware utilization, and the number of splits changes with the batch statistics (the query/kv length of each request). There are more queries in prompt1 + prompt2 than in prompt1 alone, so our scheduler splits the kv into more chunks when there is only prompt1 than when there is prompt1 + prompt2. Mathematically the two are equivalent; the output difference should come from floating point errors.
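To make this concrete, here is a minimal NumPy sketch (not FlashInfer's actual kernel) of how split-k style attention merges per-chunk partial softmax states: processing the kv in one chunk versus two chunks is mathematically equivalent, but the different summation order can leave a tiny floating point difference in the output.

```python
# Minimal sketch: each kv chunk produces a partial state (max, sum-exp, partial output)
# that is merged at the end, as in split-k / flash-decoding style attention.
import numpy as np

def attn_chunk(q, k, v):
    s = k @ q                                  # scores for this kv chunk
    m = s.max()                                # running max for numerical stability
    p = np.exp(s - m)
    return m, p.sum(), (p @ v) / p.sum()       # (max, sum-exp, partial output)

def merge(a, b):
    (ma, la, oa), (mb, lb, ob) = a, b
    m = max(ma, mb)
    la, lb = la * np.exp(ma - m), lb * np.exp(mb - m)
    return m, la + lb, (oa * la + ob * lb) / (la + lb)

rng = np.random.default_rng(0)
q = rng.standard_normal(64).astype(np.float32)
k = rng.standard_normal((256, 64)).astype(np.float32)
v = rng.standard_normal((256, 64)).astype(np.float32)

one_chunk = attn_chunk(q, k, v)[2]
two_chunks = merge(attn_chunk(q, k[:128], v[:128]),
                   attn_chunk(q, k[128:], v[128:]))[2]
print(np.abs(one_chunk - two_chunks).max())    # tiny but typically nonzero
```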

@sitabulaixizawaluduo (Author) commented

> How large is the discrepancy?
>
> If there is only a tiny difference, I think it's normal: internally we use dynamic split-k to maximize hardware utilization, and the number of splits changes with the batch statistics (the query/kv length of each request). There are more queries in prompt1 + prompt2 than in prompt1 alone, so our scheduler splits the kv into more chunks when there is only prompt1 than when there is prompt1 + prompt2. Mathematically the two are equivalent; the output difference should come from floating point errors.

I used the same batch of data for evaluation and scored the outputs with GPT-4; the scores for single-request inference and batched inference differ by 5-10 points. Which part of the calculation produces these "floating point errors"?

@yzh119 (Collaborator) commented Dec 24, 2024

> the scores for single-request inference and batched inference differ by 5-10 points

I suppose it's also related to the serving engine implementation and random seeds (especially when you use top-p/top-k sampling), not only the attention implementation. Also, the GEMM implementation's split-k choice differs when you change the number of requests.

One of the simplest ways to eliminate sampling randomness is to use greedy sampling.
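For concreteness, a small engine-agnostic sketch (plain NumPy, not any particular serving API) contrasting the two: greedy decoding always returns the argmax and is therefore deterministic given identical logits, while top-p sampling depends on the RNG state.

```python
import numpy as np

def greedy(logits):
    # Deterministic: always pick the highest-probability token.
    return int(np.argmax(logits))

def top_p(logits, p=0.9, rng=None):
    # Nucleus sampling: sample from the smallest token set whose cumulative
    # probability reaches p; the result depends on the random generator.
    rng = rng or np.random.default_rng()
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    order = np.argsort(-probs)
    cutoff = np.searchsorted(np.cumsum(probs[order]), p) + 1
    keep = order[:cutoff]
    return int(rng.choice(keep, p=probs[keep] / probs[keep].sum()))

logits = np.random.default_rng(0).standard_normal(32000)
print(greedy(logits), greedy(logits))   # identical every time
print(top_p(logits), top_p(logits))     # may differ from call to call
```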

> Which part of the calculation produces these "floating point errors"?

Floating point accumulation is neither associative nor commutative ((a+b)+c != a+(b+c), a+b+c != a+c+b).
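A two-line illustration with plain Python doubles (the effect is larger in fp16/bf16):

```python
a, b, c = 1e16, -1e16, 1.0
print((a + b) + c)   # 1.0
print(a + (b + c))   # 0.0 -- b + c rounds back to -1e16, so the 1.0 is lost
```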

@sitabulaixizawaluduo (Author) commented

> I suppose it's also related to the serving engine implementation and random seeds (especially when you use top-p/top-k sampling), not only the attention implementation.

To eliminate the influence of randomness, I already use greedy sampling.

@yzh119 (Collaborator) commented Dec 24, 2024

Can you compare the layerwise features at the first generated step for both settings? That would help identify the issue.
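For example, one way to do this is with PyTorch forward hooks; below is a minimal sketch where the `nn.Sequential` stack is only a toy stand-in, so the hook registration needs to be adapted to the actual decoder layer modules of your model and serving engine.

```python
import torch
import torch.nn as nn

# Toy stand-in for the decoder stack; replace with your model's layer modules.
model = nn.Sequential(*[nn.Linear(16, 16) for _ in range(4)])

captured = {}

def make_hook(name):
    def hook(module, inputs, output):
        # Some transformer layers return tuples; keep only the hidden states.
        hidden = output[0] if isinstance(output, tuple) else output
        captured[name] = hidden.detach().float().cpu()
    return hook

handles = [layer.register_forward_hook(make_hook(f"layer_{i}"))
           for i, layer in enumerate(model)]

with torch.no_grad():
    model(torch.randn(1, 16))   # run the prefill / first generated step

for h in handles:
    h.remove()

# Dump `captured` for the single-request run and the batched run, then diff, e.g.
# (captured_single["layer_0"] - captured_batched["layer_0"]).abs().max()
```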

@sitabulaixizawaluduo (Author) commented

> Can you compare the layerwise features at the first generated step for both settings? That would help identify the issue.

I printed the first layer's last_hidden_states for each request:

[two screenshots of the printed layerwise outputs]

The first screenshot is a batch of two requests, the second is a batch of three requests; the first two requests are consistent between the two runs. q denotes the query after RoPE.

@yzh119 (Collaborator) commented Dec 25, 2024

It's very likely that the different number of chunks in split-k causes the difference. Currently FlashInfer uses the data type of Q as the intermediate partial attention output data type in split-k, and bfloat16 is more error-prone than float16 because it has fewer mantissa bits (7 vs 10). For DTypeQ = bfloat16, it might be more accurate to use float32 for the intermediate partial attention output; I can support this soon.
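A small PyTorch sketch (not FlashInfer internals) of why the intermediate data type matters: rounding each partial result to bfloat16 before the final merge loses noticeably more precision than keeping the partials in float32.

```python
import torch

torch.manual_seed(0)
# Pretend these are the per-chunk partial attention outputs produced by split-k.
parts = [torch.randn(4096, dtype=torch.float32) for _ in range(8)]
reference = torch.stack(parts).sum(0)

merged_fp32 = sum(parts)                                        # keep partials in fp32
merged_bf16 = sum(p.to(torch.bfloat16) for p in parts).float()  # round partials to bf16 first

print((merged_fp32 - reference).abs().max())   # essentially zero
print((merged_bf16 - reference).abs().max())   # noticeably larger
```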

I have several other details to confirm:

  1. What's the KV-Cache length per request in your case? We only activate split-k when the sequence length is greater than 128, so if you use short sequence lengths, the issue might not be caused by split-k numerics.
  2. Can you try converting the model to fp16 (if possible) and run the scripts again? fp16 should be more accurate as the split-k intermediate data type.

@yzh119 self-assigned this Dec 25, 2024
@sitabulaixizawaluduo (Author) commented

> It's very likely that the different number of chunks in split-k causes the difference. Currently FlashInfer uses the data type of Q as the intermediate partial attention output data type in split-k, and bfloat16 is more error-prone than float16 because it has fewer mantissa bits (7 vs 10). For DTypeQ = bfloat16, it might be more accurate to use float32 for the intermediate partial attention output; I can support this soon.
>
> I have several other details to confirm:
>
>   1. What's the KV-Cache length per request in your case? We only activate split-k when the sequence length is greater than 128, so if you use short sequence lengths, the issue might not be caused by split-k numerics.
>   2. Can you try converting the model to fp16 (if possible) and run the scripts again? fp16 should be more accurate as the split-k intermediate data type.

I made two changes: one is turning off CUDA graph to avoid using split-kv in the decode stage; the other is choosing FA2 as the backend for prefill and modifying `const uint32_t min_kv_chunk_size = std::max((4096 / page_size), 1U);` so that prefill avoids split-kv (my input is shorter than 4096). But things are going to be different now.

@yzh119 (Collaborator) commented Dec 25, 2024

> turning off CUDA graph to avoid using split-kv in the decode stage

Even with CUDA graph turned off, split-kv will still be activated when the number of requests is small.

> But things are going to be different now.

Can you explain more?

btw, v0.1.6's FusedAddRMSNorm has some numerical issues when the data type is bfloat16 (fixed in #587); can you try a later version?

@yzh119 (Collaborator) commented Dec 30, 2024

@sitabulaixizawaluduo we plan to use fp32 as the split-k output data type for bf16 inputs in v0.2.1, which should resolve the discrepancy here.
