fixing rope pytest benchmark grad accumulation (#3743)
#3349 removed grad accumulation, but the rope benchmark implementation needs
an update to work with that change.
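
For context, here is a minimal, self-contained sketch (my simplification, not the harness code) of why the backward benchmark has to hold the very tensors that took part in the forward pass: only then can their `.grad` be reset between timed iterations, and without that reset every backward call accumulates into `.grad` instead of writing a fresh gradient.
```
# Minimal sketch (assumed simplification, not the benchmark harness): the
# backward benchmark needs the same tensors used in the forward pass, so it
# can reset .grad between timed iterations.
import torch

def gen_inputs():
    # stand-in for the input generator returned by rope_setup[variation]()
    return [torch.randn(64, 64, requires_grad=True)]

inputs = gen_inputs()                 # materialize once
out = (inputs[0] * 2).sum()           # stand-in forward pass
out.backward(retain_graph=True)
first = inputs[0].grad.clone()

# A second backward on the same tensors accumulates: .grad doubles.
out.backward(retain_graph=True)
assert torch.equal(inputs[0].grad, 2 * first)

# Resetting .grad between iterations keeps every timed backward identical,
# which is only possible if the harness holds these exact tensors.
inputs[0].grad = None
out.backward(retain_graph=True)
assert torch.equal(inputs[0].grad, first)
```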

Reference implementation:
```
   Model                            Batch-Size  Sequence-Length  ...  Forward-Time(ms)  Backward-Kernels  Backward-Time(ms)
0  Llama-2-7b-hf                             2             4096  ...             0.166                 5              0.857
0  Llama-3-8B                                2             8192  ...             0.567                 5              1.433
0  mistralai/Mistral-Nemo-Base-2407          1             4096  ...             0.138                 6              0.166
0  Qwen/Qwen2.5-7B-Instruct                  1             4096  ...             0.072                 8              0.397
0  microsoft/Phi-3.5-mini-instruct           1             8192  ...             0.236                 6              0.494
```
After an l2_cache clear:
```
   Model                            Batch-Size  Sequence-Length  ...  Forward-Time(ms)  Backward-Kernels  Backward-Time(ms)
0  Llama-2-7b-hf                             2             4096  ...             0.166                 5              0.870
0  Llama-3-8B                                2             8192  ...             0.567                 5              1.444
0  mistralai/Mistral-Nemo-Base-2407          1             4096  ...             0.138                 6              0.192
0  Qwen/Qwen2.5-7B-Instruct                  1             4096  ...             0.072                 8              0.417
0  microsoft/Phi-3.5-mini-instruct           1             8192  ...             0.234                 6              0.516
```

Before this PR:
```
Name (time in us)                                                                       Mean                    Median
---------------------------------------------------------------------------------------------------------------------------------
test_rope_bwd_benchmark[executor='thunder'-variation='llama_2_7b_hf_rope']        1,192.8558 (14.56)        1,191.9040 (14.53)
test_rope_bwd_benchmark[executor='thunder'-variation='llama_3_8B_rope']           1,767.5348 (21.58)        1,766.8410 (21.54)
test_rope_bwd_benchmark[executor='thunder'-variation='hf_mistral_nemo_rope']        275.4680 (3.36)           275.7265 (3.36)
test_rope_bwd_benchmark[executor='thunder'-variation='hf_qwen2_rope']               488.4243 (5.96)           488.3105 (5.95)
test_rope_bwd_benchmark[executor='thunder'-variation='hf_phi3_rope']                757.9140 (9.25)           757.6910 (9.24)
---------------------------------------------------------------------------------------------------------------------------------
```

In this PR:
```
Name (time in us)                                                                    Mean                Median
-----------------------------------------------------------------------------------------------------------------------
test_rope_bwd_benchmark[executor='thunder'-variation='llama_2_7b_hf_rope']       871.5996 (5.23)       871.6050 (5.24)
test_rope_bwd_benchmark[executor='thunder'-variation='llama_3_8B_rope']        1,443.0095 (8.66)     1,442.9955 (8.67)
test_rope_bwd_benchmark[executor='thunder'-variation='hf_mistral_nemo_rope']     166.5515 (1.0)        166.4480 (1.0)
test_rope_bwd_benchmark[executor='thunder'-variation='hf_qwen2_rope']            386.4463 (2.32)       386.5565 (2.32)
test_rope_bwd_benchmark[executor='thunder'-variation='hf_phi3_rope']             452.3351 (2.72)       452.0685 (2.72)
-----------------------------------------------------------------------------------------------------------------------
```

Given the existing issue with pytest/torch.profiler, running each benchmark
separately instead gives:
```
test_rope_bwd_benchmark[executor='thunder'-variation='llama_2_7b_hf_rope']     871.1912  871.2465
test_rope_bwd_benchmark[executor='thunder'-variation='llama_3_8B_rope']        1.4427  1.4427
test_rope_bwd_benchmark[executor='thunder'-variation='hf_mistral_nemo_rope']   191.6567  191.6795
test_rope_bwd_benchmark[executor='thunder'-variation='hf_qwen2_rope']          416.8007  416.8935
test_rope_bwd_benchmark[executor='thunder'-variation='hf_phi3_rope']           514.7512  514.4900
```
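
For reference, a hypothetical way to reproduce such an isolated run (the `-k` filter string is an assumption based on the test ids above; `--benchmark-only` is a standard pytest-benchmark flag):
```
# Hypothetical isolated run: select a single rope backward benchmark so
# profiler state left behind by other tests cannot interfere.
import pytest

pytest.main([
    "benchmarks/python/test_rope.py",
    "-k", "rope_bwd and llama_2_7b_hf_rope",  # assumed filter for one variation
    "--benchmark-only",
])
```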

So these numbers do match the manual benchmark with l2_cache cleared. I
think that justifies this PR.
jjsjann123 authored Jan 24, 2025
1 parent 6b0b17e commit 8689c33
Showing 1 changed file with 7 additions and 5 deletions.

benchmarks/python/test_rope.py
```
@@ -31,14 +31,15 @@ def test_rope_fwd_benchmark(
     elif executor == "torchcompile":
         clear_dynamo_cache()
 
-    model, inputs, _, _ = rope_setup[variation]()
+    model, gen_inputs, _, _ = rope_setup[variation]()
+    inputs = gen_inputs()
 
     def fwd_call(inp):
         return model(*inp)
 
     # Compile the fwd fn for torchcompile
     benchmark_fn = with_executor(executor, fwd_call, **kwargs)
-    run_benchmark(benchmark, benchmark_fn, inputs())
+    run_benchmark(benchmark, benchmark_fn, inputs)
@@ -65,14 +66,15 @@ def test_rope_bwd_benchmark(
     elif executor == "torchcompile":
         clear_dynamo_cache()
 
-    model, fwd_inputs, grad, iobytes = rope_setup[variation]()
+    model, gen_inputs, grad, iobytes = rope_setup[variation]()
+    fwd_inputs = gen_inputs()
 
     def fwd_call(inp):
         return model(*inp)
 
     # execute the compiled fwd fn
     fwd_fn = with_executor(executor, fwd_call, **kwargs)
-    outputs = fwd_fn(fwd_inputs())
+    outputs = fwd_fn(fwd_inputs)
 
     # accumulate all output, so we can feed a single grad and use the unary bwd function
     output = outputs[0]
@@ -82,5 +84,5 @@ def fwd_call(inp):
     # NOTE: the iobytes is computed based on how thunder autograd worked. So this is just
     # a reference point for torchcompile and eager executor for comparison.
     run_benchmark(
-        benchmark, unary_bwd_torch, [output, grad(), fwd_inputs()], iobytes=iobytes()
+        benchmark, unary_bwd_torch, [output, grad(), *fwd_inputs], iobytes=iobytes()
     )
```
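
To make the new contract explicit: after this change, `rope_setup[variation]()` returns the model plus three callables, and the caller materializes the inputs exactly once. A sketch of that shape (the body of each callable here is illustrative, not the real setup):
```
# Illustrative setup entry matching the shape the diff relies on; the real
# rope_setup variations build actual RoPE modules and thunder-derived iobytes.
import torch

def example_rope_setup():
    model = torch.nn.Linear(128, 128)

    def gen_inputs():
        # returns fresh tensors on every call; the benchmark now calls this
        # once and reuses the result for the forward pass and the bwd args
        return [torch.randn(2, 128, requires_grad=True)]

    def grad():
        return torch.randn(2, 128)  # incoming gradient for the unary bwd fn

    def iobytes():
        return 0  # placeholder reference IO-byte count

    return model, gen_inputs, grad, iobytes

model, gen_inputs, grad, iobytes = example_rope_setup()
fwd_inputs = gen_inputs()   # generated once, as in the diff
outputs = model(*fwd_inputs)
```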
