Add a benchmark for portions of LitGPT model other than SDPA #148
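# This file benchmarks the portions of the LitGPT model other than scaled dot-product attention (SDPA):
# the initial Thunder trace of a two-layer GPT is split at the SDPA calls, and each resulting region
# is benchmarked, forward and backward, under several executors.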
from dataclasses import dataclass
from functools import wraps
from itertools import groupby, product
from collections.abc import Callable, Sequence

import pytest

import torch

from litgpt import Config, GPT
from litgpt.config import configs

import thunder

from thunder.benchmarks import (
    inductor_cutlass_executor,
    inductor_triton_executor,
    thunder_executor,
    torch_compile_executor,
    torch_executor,
)
from thunder.common import CompileData, CompileStats
from thunder.core.compile_data import set_compile_data_and_stats

from thunder.core.jit_ext import thunder_general_jit
from thunder.core.langctxs import set_langctx
from thunder.core.trace import TraceCtx
from thunder.core.transforms import eval_trace
from thunder.executors.torch_compile import to_torch_translator

from thunder.tests.make_tensor import make_tensor

BATCH_SIZE = 2
CONFIG_NAMES = list(sorted(c["name"] for c in configs))
# Uncommenting the line below restricts benchmark case generation to just the Llama 2 7B config:
# CONFIG_NAMES = ["Llama-2-7b-hf",]

# There are many configurations, but only the following parameters affect the Linear layers in the model:
# - n_embd
# - padded_vocab_size
# - n_head
# - n_query_groups
# - head_size
# - bias
# - intermediate_size
# - n_expert
# Let's select only the configurations that differ in these parameters
unique_config_names = {}
for config_name in CONFIG_NAMES:
    config = Config.from_name(config_name)
    key = tuple(
        getattr(config, k)
        for k in (
            "n_embd",
            "padded_vocab_size",
            "n_head",
            "n_query_groups",
            "head_size",
            "bias",
            "intermediate_size",
            "n_expert",
        )
    )
    unique_config_names[key] = config.name

CONFIG_NAMES = list(sorted(unique_config_names.values()))

# We skip the Mixtral MoE configs because they are not supported by the current implementation
# See https://github.com/Lightning-AI/lightning-thunder/issues/124
CONFIG_NAMES = [name for name in CONFIG_NAMES if "mixtral" not in name.lower()]


def make_torch_traces_for_config(name: str):
    config = Config.from_name(name)
    # Two layers are enough to expose the fusion opportunities at the following network boundaries:
    # - Embedding layer -> 0th Transformer layer
    # - Last Transformer layer -> Output layer
    # - End of one Transformer layer -> Beginning of the next Transformer layer
    config.n_layer = 2

    model = GPT(config).to(dtype=torch.bfloat16, device="cuda")
    input_shape = (BATCH_SIZE, config.block_size)
    x = torch.randint(0, config.vocab_size, input_shape, dtype=torch.int64, device="cuda")

    # Acquire the initial trace.
    # We could use thunder.jit here, but we do not want to execute the compiled function;
    # we only want the initial trace, before any transformations:
    # jit_model = thunder.jit(model)
    # out = jit_model(x)
    # trace = thunder.last_traces(jit_model)[0]

    # We need to set up the contexts that are usually set up by the thunder.jit decorator
    cd = CompileData(fn=model, disable_preprocessing=True, executor_lookasides={})
    cs = CompileStats()
    set_langctx(thunder.torch.torchctx)
    set_compile_data_and_stats(cd, cs)
    thunder._cache_info_ctx.set({})
    prologue, trace, epilogue = thunder_general_jit(
        model, (x,), {}, sharp_edges=thunder.core.options.SHARP_EDGES_OPTIONS.ALLOW
    )
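    # Only the computation trace is used below; the prologue and epilogue traces returned by
    # thunder_general_jit are discarded.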

    # Remove subsymbols for readability of the trace
    for bsym in trace.bound_symbols:
        bsym.subsymbols = []

    producers, consumers = thunder.core.utils.producers_and_consumers(trace)

    # Remove the unpacking prims so that their outputs can be identified as inputs of the first chunk
    trace.bound_symbols = [bsym for bsym in trace.bound_symbols if bsym.sym.id != thunder.prims.unpack_trivial.id]

    # Remove the return prim so that its arguments can be identified as the outputs of the last chunk
    assert trace.bound_symbols.pop().sym.id == thunder.prims.python_return.id

    # We want to split the trace into chunks of the network between the scaled dot-product attention calls
    assert (
        len([bsym for bsym in trace.bound_symbols if bsym.sym.id == thunder.torch.scaled_dot_product_attention.id])
        == config.n_layer
    )

    # This is going to be our delimiter for splitting the trace into chunks
    thunder_sdpa = thunder.torch.scaled_dot_product_attention
    chunks = list(list(g) for k, g in groupby(trace.bound_symbols, key=lambda x: x.sym.id == thunder_sdpa.id) if not k)
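    # groupby partitions the bound symbols into runs of SDPA calls and runs of everything else;
    # keeping only the groups whose key is False drops the SDPA calls themselves, leaving
    # n_layer + 1 chunks of non-SDPA operations.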

    # Now we need to convert the chunks into a list of functions
    regions = [thunder.executors.utils.Region(producers, consumers, chunk) for chunk in chunks]

    # After this point, we will have a list of regions that are represented as regular PyTorch functions.
    # We can acquire the Python functions by calling .python_callable() on each "torch_trace" object
    torch_traces = []
    for r in regions:
        # Here we construct a trace that will be used to compile the function
        region_trace = TraceCtx(None)
        region_trace.bound_symbols = list(r.bound_symbols)
        sorted_proxy_inputs = [v.proxy for v in sorted(r.inputs, key=lambda x: x.proxy.name)]
        sorted_proxy_outputs = [v.proxy for v in sorted(r.outputs, key=lambda x: x.proxy.name)]
        region_trace.args = sorted_proxy_inputs
        region_trace.kwargs = {}
        region_trace.bound_symbols.append(thunder.prims.python_return.bind(sorted_proxy_outputs, output=()))
        region_trace = thunder.executors.passes.dce(region_trace)
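        # The dce pass above prunes bound symbols whose outputs are not needed to produce the
        # region's return value.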

        def torch_interpreted_func(*args):
            return eval_trace(region_trace, *args, symbol_mapper=to_torch_translator)

        torch_trace = thunder.trace(inline_trace=False)(torch_interpreted_func, *sorted_proxy_inputs)

        # Remove subsymbols for readability of the trace
        for bsym in torch_trace.bound_symbols:
            bsym.subsymbols = []

        torch_traces.append(torch_trace)

    return torch_traces


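# CUDA kernels are launched asynchronously, so we synchronize before and after the call to make
# sure the host-side timing covers the GPU work of the benchmarked function.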
def wrap_for_benchmark(fn):
    @wraps(fn)
    def fn_(*args, **kwargs):
        torch.cuda.synchronize()
        result = fn(*args, **kwargs)
        torch.cuda.synchronize()
        return result

    return fn_


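# Runs the region's forward pass under the given executor and then a backward pass with ones as
# the output gradients, so that each benchmark case measures both directions.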
def forward_and_backward(torch_trace: TraceCtx, jit_fn: Callable):
    fn = torch_trace.python_callable(include_no_grad=False)
    jfn = jit_fn(fn)

    @wraps(jfn)
    def wrapper(*args, **kwargs):
        result = jfn(*args, **kwargs)
        if isinstance(result, Sequence):
            torch.autograd.backward(result, [torch.ones_like(x) for x in result])
        else:
            result.backward(torch.ones_like(result))
        return result

    return wrapper


executor_names = {
    torch_executor: "eager",
    thunder_executor: "thunder",
    # Inductor with the default GEMM backend (ATen)
    torch_compile_executor: "inductor",
    # Inductor with autotuning between ATen and CUTLASS or Triton as the GEMM backend
    inductor_triton_executor: "inductor+triton_gemm",
    # At the moment, I was not able to get the CUTLASS backend to work with the Inductor executor:
    # torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
    # LoweringException: ErrorFromChoice: Error in function: cuda_cutlass_gemm
    # inductor_cutlass_executor: "inductor+cutlass_gemm",
}
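# Each key above is a callable that compiles a Python function for the corresponding backend; it is
# passed as `jit_fn` to forward_and_backward below, and its value is used in the benchmark case id.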


@dataclass
class TraceInfo:
    config_name: str
    region_idx: int
    trace: TraceCtx


# The list comprehension below is arguably easier to read, but it is kept commented out in favor of
# the explicit loop, which can print progress while the benchmark cases are being constructed:
# litgpt_traces = [
#     TraceInfo(name, i, trace) for name in CONFIG_NAMES for i, trace in enumerate(make_torch_traces_for_config(name))
# ]

litgpt_traces = []
for j, name in enumerate(CONFIG_NAMES):
    print(f"Constructing benchmark cases for config: {name} ({j + 1}/{len(CONFIG_NAMES)})")
    traces = make_torch_traces_for_config(name)
    for i, trace in enumerate(traces):
        litgpt_traces.append(TraceInfo(name, i, trace))
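# Note that this runs at module import time, so the traces are built during pytest collection,
# before any benchmark case executes.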

# Now we have a list of torch_traces that are ready to be benchmarked
trace_executor_pairs = list(product(litgpt_traces, executor_names.keys()))


@pytest.mark.parametrize(
    "info, executor",
    trace_executor_pairs,
    ids=[
        f"{info.config_name}_region{info.region_idx}_{executor_names[executor]}"
        for info, executor in trace_executor_pairs
    ],
)
def test_litgpt(benchmark, info, executor):
    torch_trace = info.trace

    def setup():
        args = []
        for a in torch_trace.args:
            torch_dtype = thunder.torch.to_torch_dtype(a.dtype)
            torch_device = thunder.core.devices.to_torch_device(a.device)
            is_float = isinstance(a.dtype, thunder.core.dtypes.floating)
            low = 0 if not is_float else None
            args.append(make_tensor(a.shape, dtype=torch_dtype, device=torch_device, requires_grad=is_float, low=low))
        return args, {}
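    # setup builds fresh random inputs for every benchmark round: floating-point arguments get
    # requires_grad=True so the backward pass has something to accumulate into, and integer
    # arguments (token indices) are bounded below by 0.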

    fn = forward_and_backward(torch_trace, executor)
    fn = wrap_for_benchmark(fn)

    benchmark.pedantic(fn, setup=setup, rounds=20, warmup_rounds=1)
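    # benchmark.pedantic is pytest-benchmark's explicit mode: setup is invoked before each of the
    # 20 measured rounds (plus 1 warmup round), so every timing uses freshly generated inputs.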