Commit
fixing rope pytest benchmark grad accumulation (#3743)
#3349 removed grad accumulation, but the rope benchmark implementation needs an update to get that working.

Reference implementation:
```
                              Model  Batch-Size  Sequence-Length  ...  Forward-Time(ms)  Backward-Kernels  Backward-Time(ms)
0                     Llama-2-7b-hf           2             4096  ...             0.166                 5              0.857
0                        Llama-3-8B           2             8192  ...             0.567                 5              1.433
0  mistralai/Mistral-Nemo-Base-2407           1             4096  ...             0.138                 6              0.166
0          Qwen/Qwen2.5-7B-Instruct           1             4096  ...             0.072                 8              0.397
0   microsoft/Phi-3.5-mini-instruct           1             8192  ...             0.236                 6              0.494
```

After l2_cache clear:
```
                              Model  Batch-Size  Sequence-Length  ...  Forward-Time(ms)  Backward-Kernels  Backward-Time(ms)
0                     Llama-2-7b-hf           2             4096  ...             0.166                 5              0.870
0                        Llama-3-8B           2             8192  ...             0.567                 5              1.444
0  mistralai/Mistral-Nemo-Base-2407           1             4096  ...             0.138                 6              0.192
0          Qwen/Qwen2.5-7B-Instruct           1             4096  ...             0.072                 8              0.417
0   microsoft/Phi-3.5-mini-instruct           1             8192  ...             0.234                 6              0.516
```

Before this PR:
```
Name (time in us)                                                                           Mean              Median
--------------------------------------------------------------------------------------------------------------------
test_rope_bwd_benchmark[executor='thunder'-variation='llama_2_7b_hf_rope']    1,192.8558 (14.56)  1,191.9040 (14.53)
test_rope_bwd_benchmark[executor='thunder'-variation='llama_3_8B_rope']       1,767.5348 (21.58)  1,766.8410 (21.54)
test_rope_bwd_benchmark[executor='thunder'-variation='hf_mistral_nemo_rope']     275.4680 (3.36)     275.7265 (3.36)
test_rope_bwd_benchmark[executor='thunder'-variation='hf_qwen2_rope']            488.4243 (5.96)     488.3105 (5.95)
test_rope_bwd_benchmark[executor='thunder'-variation='hf_phi3_rope']             757.9140 (9.25)     757.6910 (9.24)
--------------------------------------------------------------------------------------------------------------------
```

In this PR:
```
Name (time in us)                                                                          Mean             Median
------------------------------------------------------------------------------------------------------------------
test_rope_bwd_benchmark[executor='thunder'-variation='llama_2_7b_hf_rope']      871.5996 (5.23)    871.6050 (5.24)
test_rope_bwd_benchmark[executor='thunder'-variation='llama_3_8B_rope']       1,443.0095 (8.66)  1,442.9955 (8.67)
test_rope_bwd_benchmark[executor='thunder'-variation='hf_mistral_nemo_rope']     166.5515 (1.0)     166.4480 (1.0)
test_rope_bwd_benchmark[executor='thunder'-variation='hf_qwen2_rope']           386.4463 (2.32)    386.5565 (2.32)
test_rope_bwd_benchmark[executor='thunder'-variation='hf_phi3_rope']            452.3351 (2.72)    452.0685 (2.72)
------------------------------------------------------------------------------------------------------------------
```

Given the existing issue with pytest/torch.profiler, if I instead run each benchmark separately:
```
test_rope_bwd_benchmark[executor='thunder'-variation='llama_2_7b_hf_rope']    871.1912  871.2465
test_rope_bwd_benchmark[executor='thunder'-variation='llama_3_8B_rope']         1.4427    1.4427
test_rope_bwd_benchmark[executor='thunder'-variation='hf_mistral_nemo_rope']  191.6567  191.6795
test_rope_bwd_benchmark[executor='thunder'-variation='hf_qwen2_rope']         416.8007  416.8935
test_rope_bwd_benchmark[executor='thunder'-variation='hf_phi3_rope']          514.7512  514.4900
```

So these numbers do match the manual benchmark with l2_cache cleared. I think that justifies this PR.
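
As a rough check of the "numbers match" claim, the sketch below compares the separately-run pytest means against the manual backward times measured after the l2_cache clear. It rests on two assumptions that the text does not state: that the manual rows and the pytest variations correspond in the order listed (for example Llama-2-7b-hf to `llama_2_7b_hf_rope`), and that the `1.4427` figure for `llama_3_8B_rope` was reported in ms rather than us. The helper name `relative_diff` is hypothetical.

```python
# Manual backward times (ms) after the l2_cache clear, taken from the table above.
# Assumption: keys map each model row to the pytest variation in the same position.
manual_ms = {
    "llama_2_7b_hf_rope": 0.870,
    "llama_3_8B_rope": 1.444,
    "hf_mistral_nemo_rope": 0.192,
    "hf_qwen2_rope": 0.417,
    "hf_phi3_rope": 0.516,
}

# pytest means (us) from running each benchmark separately.
# Assumption: the llama_3_8B_rope value was printed as 1.4427 ms, entered here as 1442.7 us.
pytest_us = {
    "llama_2_7b_hf_rope": 871.1912,
    "llama_3_8B_rope": 1442.7,
    "hf_mistral_nemo_rope": 191.6567,
    "hf_qwen2_rope": 416.8007,
    "hf_phi3_rope": 514.7512,
}


def relative_diff(manual_ms_value: float, pytest_us_value: float) -> float:
    """Relative difference between a manual time in ms and a pytest time in us."""
    manual_us = manual_ms_value * 1000.0
    return abs(pytest_us_value - manual_us) / manual_us


diffs = {name: relative_diff(manual_ms[name], pytest_us[name]) for name in manual_ms}

for name, diff in diffs.items():
    print(f"{name}: {diff:.2%}")
```

Under those assumptions, every variation agrees with the manual measurement to well within 1%.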