Add a benchmark for portions of LitGPT model other than SDPA #148
Closed

Changes from all commits (24 commits):
- 552f415  Allow skipping no_grad decorator (IvanYashchuk)
- 79b136b  Add torchctx to thunder.torch (IvanYashchuk)
- 3427326  Add initial version of chunking litgpt network and benchmarking (IvanYashchuk)
- 9d02ea7  Add parametrization over the config name (IvanYashchuk)
- 26493fa  Use TraceInfo dataclass instead of plain tuples (IvanYashchuk)
- e87c95d  batch_size -> BATCH_SIZE (IvanYashchuk)
- 8b51449  Generate cases for all litgpt configs (IvanYashchuk)
- f08cf10  Add region index to TraceInfo (IvanYashchuk)
- ba9f360  Move make_tensor import to the header (IvanYashchuk)
- 2706b69  Skip mixtral (IvanYashchuk)
- d189fda  Print progress (IvanYashchuk)
- 5f70895  Filter out non unique configs (IvanYashchuk)
- 1b73e3c  Add inductor with gemm autotuning (IvanYashchuk)
- 995229b  Use CONFIG_NAME when constructing unique_config_names (IvanYashchuk)
- ce227ad  formatting (IvanYashchuk)
- 101fae4  Skip inductor+cutlass as it's raising an error (IvanYashchuk)
- 2cbc8a4  Rename executors to include gemm in the name (IvanYashchuk)
- e61d501  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
- 377f675  Merge branch 'main' into litgpt-chunks-bench (t-vi)
- 9cb3ca0  The type of general_jit return was changed in https://github.com/Ligh… (IvanYashchuk)
- 2799f11  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
- e351685  Merge remote-tracking branch 'upstream/main' into litgpt-chunks-bench (IvanYashchuk)
- fe64f0e  Remove ATEN from gemm_backend (IvanYashchuk)
- 57cec8a  Split benchmarks into forward and backward; make pytest-benchmark groups (IvanYashchuk)
@@ -0,0 +1,286 @@
from collections.abc import Callable, Sequence
from dataclasses import dataclass
from functools import wraps
from itertools import groupby, product

import pytest

import torch

from litgpt import Config, GPT
from litgpt.config import configs

import thunder

from thunder.benchmarks import (
    inductor_cutlass_executor,
    inductor_triton_executor,
    thunder_executor,
    torch_compile_executor,
    torch_executor,
)
from thunder.common import CompileData, CompileStats
from thunder.core.compile_data import set_compile_data_and_stats

from thunder.core.jit_ext import thunder_general_jit
from thunder.core.langctxs import set_langctx
from thunder.core.trace import TraceCtx
from thunder.core.transforms import eval_trace
from thunder.executors.torch_compile import to_torch_translator

from thunder.tests.make_tensor import make_tensor

BATCH_SIZE = 2
CONFIG_NAMES = list(sorted(c["name"] for c in configs))
# CONFIG_NAMES = ["Llama-2-7b-hf",]

# There are many configurations but only the following parameters affect the Linear layers in the model:
# - n_embd
# - padded_vocab_size
# - n_head
# - n_query_groups
# - head_size
# - bias
# - intermediate_size
# - n_expert
# Let's select only the configurations that differ in these parameters
unique_config_names = {}
for config_name in CONFIG_NAMES:
    config = Config.from_name(config_name)
    key = tuple(
        getattr(config, k)
        for k in (
            "n_embd",
            "padded_vocab_size",
            "n_head",
            "n_query_groups",
            "head_size",
            "bias",
            "intermediate_size",
            "n_expert",
        )
    )
    unique_config_names[key] = config.name

CONFIG_NAMES = list(sorted(unique_config_names.values()))

# We will skip the Mixtral MoE configs because they are not supported by the current implementation
# See https://github.com/Lightning-AI/lightning-thunder/issues/124
CONFIG_NAMES = [name for name in CONFIG_NAMES if "mixtral" not in name.lower()]

def make_torch_traces_for_config(name: str):
    config = Config.from_name(name)
    # Two layers are enough to expose the fusion opportunities at the following network boundaries:
    # - Embedding layer -> 0th Transformer layer
    # - Last Transformer layer -> Output layer
    # - End of one Transformer layer -> Beginning of the next Transformer layer
    config.n_layer = 2

    model = GPT(config).to(dtype=torch.bfloat16, device="cuda")
    input_shape = (BATCH_SIZE, config.block_size)
    x = torch.randint(0, config.vocab_size, input_shape, dtype=torch.int64, device="cuda")

    # Acquire the initial trace
    # We could use thunder.jit here, but we only want the initial trace before any transformations,
    # without executing the compiled function:
    # jit_model = thunder.jit(model)
    # out = jit_model(x)
    # trace = thunder.last_traces(jit_model)[0]

    # We need to set up contexts that are usually set up by the thunder.jit decorator
    cd = CompileData(fn=model, disable_preprocessing=True, executor_lookasides={})
    cs = CompileStats()
    set_langctx(thunder.torch.torchctx)
    set_compile_data_and_stats(cd, cs)
    thunder._cache_info_ctx.set({})
    jit_results = thunder_general_jit(model, (x,), {}, sharp_edges=thunder.core.options.SHARP_EDGES_OPTIONS.ALLOW)
    prologue = jit_results.prologue_trace
    trace = jit_results.computation_trace
    epilogue = jit_results.epilogue_trace
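    # Only the computation trace is chunked below; the prologue trace (input unpacking and
    # validation) and the epilogue trace are not used further in this benchmark.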

    # Remove subsymbols for readability of the trace
    for bsym in trace.bound_symbols:
        bsym.subsymbols = []

    producers, consumers = thunder.core.utils.producers_and_consumers(trace)

    # Remove unpacking prims so that they can be identified as inputs of the first chunk
    trace.bound_symbols = [bsym for bsym in trace.bound_symbols if bsym.sym.id != thunder.prims.unpack_trivial.id]

    # Remove return prim so that it can be identified as the output of the last chunk
    assert trace.bound_symbols.pop().sym.id == thunder.prims.python_return.id

    # We want to split the trace into chunks of network between the scaled dot-product attention calls
    assert (
        len([bsym for bsym in trace.bound_symbols if bsym.sym.id == thunder.torch.scaled_dot_product_attention.id])
        == config.n_layer
    )

    # This is going to be our delimiter for splitting the trace into chunks
    thunder_sdpa = thunder.torch.scaled_dot_product_attention
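    # groupby produces alternating runs of SDPA and non-SDPA bound symbols; keeping only the
    # groups whose key is False drops the SDPA calls and leaves the segments between them.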
    chunks = list(list(g) for k, g in groupby(trace.bound_symbols, key=lambda x: x.sym.id == thunder_sdpa.id) if not k)

    # Now we need to convert the chunks into a list of functions
    regions = [thunder.executors.utils.Region(producers, consumers, chunk) for chunk in chunks]

    # After this point, we will have a list of regions that are represented as regular PyTorch functions
    # We can acquire the Python functions by calling .python_callable() on each "torch_trace" object
    torch_traces = []
    for r in regions:
        # Here we construct a trace that will be used to compile the function
        region_trace = TraceCtx(None)
        region_trace.bound_symbols = list(r.bound_symbols)
        sorted_proxy_inputs = [v.proxy for v in sorted(r.inputs, key=lambda x: x.proxy.name)]
        sorted_proxy_outputs = [v.proxy for v in sorted(r.outputs, key=lambda x: x.proxy.name)]
        region_trace.args = sorted_proxy_inputs
        region_trace.kwargs = {}
        region_trace.bound_symbols.append(thunder.prims.python_return.bind(sorted_proxy_outputs, output=()))
        region_trace = thunder.executors.passes.dce(region_trace)

        def torch_interpreted_func(*args):
            return eval_trace(region_trace, *args, symbol_mapper=to_torch_translator)

        torch_trace = thunder.trace(inline_trace=False)(torch_interpreted_func, *sorted_proxy_inputs)

        # Remove subsymbols for readability of the trace
        for bsym in torch_trace.bound_symbols:
            bsym.subsymbols = []

        torch_traces.append(torch_trace)

    return torch_traces


def wrap_for_benchmark(fn):
    @wraps(fn)
    def fn_(*args, **kwargs):
        torch.cuda.synchronize()
        result = fn(*args, **kwargs)
        torch.cuda.synchronize()
        return result

    return fn_


def backward_only(torch_trace: TraceCtx, jit_fn: Callable, fw_setup_fn: Callable):
    fn = torch_trace.python_callable(include_no_grad=False)
    jfn = jit_fn(fn)
    args, kwargs = fw_setup_fn()
    result = jfn(*args, **kwargs)
    result = thunder.core.utils.sequencify(result)

    def backward_fn(*args, **kwargs):
        for a in args:
            a.grad = None

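        # The forward graph built above is reused for every benchmark round, so it must be
        # retained across repeated backward calls.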
        torch.autograd.backward(result, args, retain_graph=True)

    return backward_fn


executor_names = {
    torch_executor: "eager",
    thunder_executor: "thunder",
    # Inductor with the default GEMM backend (Aten)
    torch_compile_executor: "inductor",
    # Inductor with autotuning between Aten and CUTLASS or Triton as the GEMM backend
    inductor_triton_executor: "inductor+triton_gemm",
    # At the moment, I was not able to get the CUTLASS backend to work with the Inductor executor
    # torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
    # LoweringException: ErrorFromChoice: Error in function: cuda_cutlass_gemm
    # inductor_cutlass_executor: "inductor+cutlass_gemm",
}


@dataclass
class TraceInfo:
    config_name: str
    region_idx: int
    trace: TraceCtx

    def __repr__(self):
        return f"TraceInfo(config_name={self.config_name}, region_idx={self.region_idx})"


# litgpt_traces = [
#     TraceInfo(name, i, trace) for name in CONFIG_NAMES for i, trace in enumerate(make_torch_traces_for_config(name))
# ]
Comment on lines +206 to +208:

"Is this still needed?"

"List comprehensions are easier to read for me than the for-loop below."
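A minimal sketch (not part of this change) of how the comprehension could keep the progress output; the _traces_with_progress helper below is invented for illustration:

def _traces_with_progress(j: int, name: str):
    # Print the same progress message as the loop below, then build the traces for this config.
    print(f"Constructing benchmark cases for config: {name} ({j + 1}/{len(CONFIG_NAMES)})")
    return make_torch_traces_for_config(name)

litgpt_traces = [
    TraceInfo(name, i, trace)
    for j, name in enumerate(CONFIG_NAMES)
    for i, trace in enumerate(_traces_with_progress(j, name))
]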

# Rewrite list comprehension above to include printing progress bar
litgpt_traces = []
for j, name in enumerate(CONFIG_NAMES):
    print(f"Constructing benchmark cases for config: {name} ({j + 1}/{len(CONFIG_NAMES)})")
    traces = make_torch_traces_for_config(name)
    for i, trace in enumerate(traces):
        litgpt_traces.append(TraceInfo(name, i, trace))

# Now we have a list of torch_traces that are ready to be benchmarked
trace_executor_pairs = list(product(litgpt_traces, (executor_names.keys())))


@pytest.mark.parametrize(
    "info, executor",
    trace_executor_pairs,
    ids=[
        f"{info.config_name}_region{info.region_idx}_{executor_names[executor]}"
        for info, executor in trace_executor_pairs
    ],
)
@pytest.mark.benchmark(group="forward")
def test_litgpt_forward(benchmark, info, executor):
    torch_trace = info.trace

    def setup():
        args = []
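        # Integer-typed trace inputs are token indices, so they are sampled as non-negative
        # values; only floating-point inputs are created with requires_grad=True.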
        for a in torch_trace.args:
            torch_dtype = thunder.torch.to_torch_dtype(a.dtype)
            torch_device = thunder.core.devices.to_torch_device(a.device)
            is_float = isinstance(a.dtype, thunder.core.dtypes.floating)
            low = 0 if not is_float else None
            args.append(make_tensor(a.shape, dtype=torch_dtype, device=torch_device, requires_grad=is_float, low=low))
        return args, {}

    fn = torch_trace.python_callable(include_no_grad=False)
    fn = executor(fn)
    fn = wrap_for_benchmark(fn)

    benchmark.pedantic(fn, setup=setup, rounds=20, warmup_rounds=1)


@pytest.mark.parametrize(
    "info, executor",
    trace_executor_pairs,
    ids=[
        f"{info.config_name}_region{info.region_idx}_{executor_names[executor]}"
        for info, executor in trace_executor_pairs
    ],
)
@pytest.mark.benchmark(group="backward")
def test_litgpt_backward(benchmark, info, executor):
    torch_trace = info.trace

    def fw_setup():
        args = []
        for a in torch_trace.args:
            torch_dtype = thunder.torch.to_torch_dtype(a.dtype)
            torch_device = thunder.core.devices.to_torch_device(a.device)
            is_float = isinstance(a.dtype, thunder.core.dtypes.floating)
            low = 0 if not is_float else None
            args.append(make_tensor(a.shape, dtype=torch_dtype, device=torch_device, requires_grad=is_float, low=low))
        return args, {}

    def bw_setup():
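        # These tensors become the grad outputs (cotangents) passed to torch.autograd.backward,
        # so they match the shapes and dtypes of the forward outputs.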
        args = []
        for a in torch_trace.output:
            torch_dtype = thunder.torch.to_torch_dtype(a.dtype)
            torch_device = thunder.core.devices.to_torch_device(a.device)
            is_float = isinstance(a.dtype, thunder.core.dtypes.floating)
            low = 0 if not is_float else None
            args.append(make_tensor(a.shape, dtype=torch_dtype, device=torch_device, requires_grad=False, low=low))
        return args, {}

    fn = backward_only(torch_trace, executor, fw_setup)
    fn = wrap_for_benchmark(fn)

    benchmark.pedantic(fn, setup=bw_setup, rounds=20, warmup_rounds=1)
Comment on the commented-out CONFIG_NAMES = ["Llama-2-7b-hf",] line:

"Uncommenting this would force generating benchmark cases just for this Llama 2 7B config."