[litgpt benchmark] enable force_recompute_fp8_weight_in_bwd when torchao.float8 is used with FSDP2 #1528

Merged: 1 commit merged into main from crpa/ao-recompute_fp8_weight_in_bwd on Jan 20, 2025

Conversation

@crcrpar (Collaborator) commented on Dec 9, 2024

What does this PR do?

As per title: enable `force_recompute_fp8_weight_in_bwd` when `torchao.float8` is used with FSDP2.

ref: pytorch/ao@e919558

Measured with 8 H100s and pjnl-20241209. The command used was:

torchrun --nproc-per-node 8 --local-ranks-filter 0 --role rank --tee 3 thunder/benchmarks/benchmark_litgpt.py --model_name <MODEL_NAME> --compile inductor --distributed_mode fsdp2 --shard_mode zero2 --use_torchao_fp8_linear true --use_torchao_fp8_allgather true --use_torchao_fp8_precompute_scale_for_fsdp true
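For context, here is a minimal sketch of what enabling this option looks like on the torchao side, assuming torchao's public float8 training API (`Float8LinearConfig`, `convert_to_float8_training`, `precompute_float8_dynamic_scale_for_fsdp`); the model, dimensions, and wiring below are illustrative, not the benchmark script's actual code:

```python
import torch.nn as nn
from torchao.float8 import (
    Float8LinearConfig,
    convert_to_float8_training,
    precompute_float8_dynamic_scale_for_fsdp,
)

# Illustrative config: fp8 all-gather for FSDP2, plus recomputing the
# fp8-cast weight in backward instead of saving it from forward.
config = Float8LinearConfig(
    enable_fsdp_float8_all_gather=True,
    force_recompute_fp8_weight_in_bwd=True,
)

# Hypothetical toy model; the benchmark uses litgpt models instead.
model = nn.Sequential(nn.Linear(4096, 4096), nn.Linear(4096, 4096))
convert_to_float8_training(model, config=config)

# The benchmark's --use_torchao_fp8_precompute_scale_for_fsdp flag roughly
# corresponds to calling, after FSDP2 wrapping and each optimizer.step():
# precompute_float8_dynamic_scale_for_fsdp(model)
```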

Llama-2-7b-hf:

| branch | perf (tokens/s/gpu) | mem usage (GB) |
| --- | --- | --- |
| main | 13947.29 | 34.26 |
| this PR | 13995.80 | 27.69 |

Llama-3-8B:

| branch | perf (tokens/s/gpu) | mem usage (GB) |
| --- | --- | --- |
| main | 12404.18 | 58.65 |
| this PR | 12414.15 | 51.67 |

cc @crcrpar

ref: pytorch/ao@e919558
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
@crcrpar force-pushed the crpa/ao-recompute_fp8_weight_in_bwd branch from a824ae4 to 9ad3327 on December 17, 2024 at 14:02
@lantiga merged commit 77c6a74 into main on Jan 20, 2025
41 checks passed
@lantiga deleted the crpa/ao-recompute_fp8_weight_in_bwd branch on January 20, 2025 at 08:41
riccardofelluga pushed a commit that referenced this pull request on Jan 27, 2025:
[litgpt benchmark] enable `force_recompute_fp8_weight_in_bwd` when `torchao.float8` is used with FSDP2 (#1528)

Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>