Context parallelism understanding #723

Open · jinsong-mao opened this issue Dec 9, 2024 · 4 comments
Labels: context_parallel, question (Further information is requested)

jinsong-mao commented Dec 9, 2024

Hi,

We have recently been testing the CP parallelism strategy with a 2D configuration: FSDP + CP.
From what we understand, CP slices the sequence dimension, but the attention kernel still needs to compute attention over the whole sequence, which means each GPU has to gather the sharded K/V from the other CP ranks using some collective communication kernels.
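To make our mental model concrete, here is a rough sketch written against PyTorch's experimental torch.distributed.tensor.experimental.context_parallel API (this is not torchtitan's actual code, and the mesh sizes and the way torchtitan wires this up internally may differ):

```python
# Sketch of our understanding: context_parallel() shards the listed buffers
# in place along their sequence dim across the CP mesh and dispatches SDPA to
# a ring-style implementation that exchanges K/V shards between ranks, so each
# rank still attends over the full sequence. Those K/V exchanges are the
# collective kernels we expected to see in the trace.
# Launch (illustrative): torchrun --nproc_per_node=4 cp_sketch.py
import os

import torch
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.experimental import context_parallel


def main():
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    cp_mesh = init_device_mesh("cuda", (4,), mesh_dim_names=("cp",))

    # Full-sequence Q/K/V replicated on every rank before entering the context.
    B, H, S, D = 2, 8, 8192, 64
    q = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    with context_parallel(cp_mesh, buffers=[q, k, v], buffer_seq_dims=[2, 2, 2]):
        # q/k/v are now local sequence shards; the SDPA call below is what
        # triggers the CP communication (exchange of K and V shards).
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

    print(out.shape)  # (B, H, S / cp_degree, D) on each rank


if __name__ == "__main__":
    main()
```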

However, we didn't see any such kernels in the trace; we only found the all-gather for parameters in the pre-forward phase.
[screenshot: profiler trace]

Is there anything we have misunderstood? Please share your comments so we can understand this better.

Thanks.

tianyu-l added the question (Further information is requested) label on Dec 9, 2024
tianyu-l (Contributor) commented Dec 9, 2024

cc: @XilunWu @fegin

fegin (Contributor) commented Dec 9, 2024

[screenshot: profiler trace with the CP-issued all-gather selected]

This is the trace from running:

CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --training.data_parallel_replicate_degree=1 --training.data_parallel_shard_degree=2 --experimental.context_parallel_degree=4

The selected all-gather is the one issued by CP, which is not the same as the FSDP one. I'm wondering what command you used?

jinsong-mao (Author) commented

@fegin Thanks for your feedback.
I was running the experiments on an AMD platform; I suppose you were using NVIDIA GPUs.
My command is almost the same as yours, but I can only find the all-gather stream issued by FSDP, and no communication streams issued by CP.
Perhaps, because of the hardware difference, the kernels issued by CP were not executed or were not captured by the profiler.

I'm not sure why there is no such trace on the AMD GPUs.

Thanks.

fegin (Contributor) commented Dec 10, 2024

I see. My best guess is that, because of the hardware, SDPA doesn't dispatch to the right kernel, so it never reaches the CP implementation. We currently only support the Flash attention and memory-efficient attention kernels; the math attention kernel is not supported. https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
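If it helps, here is a quick standalone check (just a sketch, not part of torchtitan) you could run on the AMD machine to see whether the flash / memory-efficient SDPA backends are actually usable there; if neither is, SDPA falls back to the math kernel and CP won't kick in:

```python
# Sketch: probe which SDPA backends actually run on this machine.
# If only the math backend works, the CP dispatch never triggers and no CP
# collectives will show up in the profiler trace.
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

for backend in (SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION):
    try:
        # Restrict SDPA to this single backend; it raises a RuntimeError if
        # the backend can't handle these inputs on this hardware/build.
        with sdpa_kernel(backend):
            F.scaled_dot_product_attention(q, k, v, is_causal=True)
        print(f"{backend.name}: OK")
    except RuntimeError as err:
        print(f"{backend.name}: unavailable ({err})")
```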
