Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a benchmark for portions of LitGPT model other than SDPA #148

Closed
wants to merge 24 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
552f415
Allow skipping no_grad decorator
IvanYashchuk Apr 9, 2024
79b136b
Add torchctx to thunder.torch
IvanYashchuk Apr 9, 2024
3427326
Add initial version of chunking litgpt network and benchmarking
IvanYashchuk Apr 9, 2024
9d02ea7
Add parametrization over the config name
IvanYashchuk Apr 9, 2024
26493fa
Use TraceInfo dataclass instead of plain tuples
IvanYashchuk Apr 9, 2024
e87c95d
batch_size -> BATCH_SIZE
IvanYashchuk Apr 9, 2024
8b51449
Generate cases for all litgpt configs
IvanYashchuk Apr 9, 2024
f08cf10
Add region index to TraceInfo
IvanYashchuk Apr 9, 2024
ba9f360
Move make_tensor import to the header
IvanYashchuk Apr 9, 2024
2706b69
Skip mixtral
IvanYashchuk Apr 9, 2024
d189fda
Print progress
IvanYashchuk Apr 9, 2024
5f70895
Filter out non unique configs
IvanYashchuk Apr 9, 2024
1b73e3c
Add inductor with gemm autotuning
IvanYashchuk Apr 9, 2024
995229b
Use CONFIG_NAME when constructing unique_config_names
IvanYashchuk Apr 9, 2024
ce227ad
formatting
IvanYashchuk Apr 9, 2024
101fae4
Skip inductor+cutlass as it's raising an error
IvanYashchuk Apr 9, 2024
2cbc8a4
Rename executors to include gemm in the name
IvanYashchuk Apr 9, 2024
e61d501
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 9, 2024
377f675
Merge branch 'main' into litgpt-chunks-bench
t-vi Apr 11, 2024
9cb3ca0
The type of general_jit return was changed in https://github.com/Ligh…
IvanYashchuk Apr 17, 2024
2799f11
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 17, 2024
e351685
Merge remote-tracking branch 'upstream/main' into litgpt-chunks-bench
IvanYashchuk Apr 18, 2024
fe64f0e
Remove ATEN from gemm_backend
IvanYashchuk Apr 18, 2024
57cec8a
Split benchmarks into forward and backward; make pytest-benchmark groups
IvanYashchuk Apr 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions thunder/benchmarks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -707,6 +707,32 @@ def torch_compile_executor(fn: Callable) -> Callable:
return torch.compile(fn)


def inductor_gemm_executor(fn: Callable, gemm_backend: str) -> Callable:
    """Compile ``fn`` with ``torch.compile`` and run it with GEMM autotuning on.

    Args:
        fn: The callable to compile.
        gemm_backend: Value assigned to
            ``torch._inductor.config.max_autotune_gemm_backends`` around
            every call of the returned wrapper.

    Returns:
        A wrapper that overrides the Inductor GEMM autotuning settings,
        invokes the compiled callable, and puts the previous settings back
        afterwards, even when the call raises.
    """
    import torch._inductor.config

    torch.backends.cuda.matmul.allow_tf32 = True
    torch._dynamo.reset()

    compiled = torch.compile(fn)

    def wrapper(*args, **kwargs):
        inductor_config = torch._inductor.config
        # Remember the current settings so they can be restored after the call
        saved_backends = inductor_config.max_autotune_gemm_backends
        saved_autotune = inductor_config.max_autotune_gemm
        try:
            inductor_config.max_autotune_gemm_backends = gemm_backend
            inductor_config.max_autotune_gemm = True
            return compiled(*args, **kwargs)
        finally:
            inductor_config.max_autotune_gemm_backends = saved_backends
            inductor_config.max_autotune_gemm = saved_autotune

    return wrapper


# Ready-made executors for the individual GEMM autotuning backends
inductor_cutlass_executor = partial(inductor_gemm_executor, gemm_backend="CUTLASS")
inductor_triton_executor = partial(inductor_gemm_executor, gemm_backend="TRITON")


def thunder_torch_executor(fn: Callable) -> Callable:
    """Jit ``fn`` with thunder using only the PyTorch executor.

    TF32 is enabled for CUDA matmuls as a side effect.
    """
    torch.backends.cuda.matmul.allow_tf32 = True
    jitted = thunder.jit(fn, executors=[thunder.pytorch_executor])
    return jitted
Expand Down
286 changes: 286 additions & 0 deletions thunder/benchmarks/litgpt_chunks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,286 @@
from collections.abc import Callable, Sequence
from dataclasses import dataclass
from functools import wraps
from itertools import groupby, product

import pytest

import torch

from litgpt import Config, GPT
from litgpt.config import configs

import thunder

from thunder.benchmarks import (
inductor_cutlass_executor,
inductor_triton_executor,
thunder_executor,
torch_compile_executor,
torch_executor,
)
from thunder.common import CompileData, CompileStats
from thunder.core.compile_data import set_compile_data_and_stats

from thunder.core.jit_ext import thunder_general_jit
from thunder.core.langctxs import set_langctx
from thunder.core.trace import TraceCtx
from thunder.core.transforms import eval_trace
from thunder.executors.torch_compile import to_torch_translator

from thunder.tests.make_tensor import make_tensor

# Batch dimension of the token tensor used to trace every model
BATCH_SIZE = 2
# Names of all LitGPT configs, in sorted order
CONFIG_NAMES = sorted(c["name"] for c in configs)
# CONFIG_NAMES = ["Llama-2-7b-hf",]
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Uncommenting this would force generating benchmark cases just for this Llama 2 7B config.


# Many configurations exist, but the Linear layers of the model are only
# affected by the parameters listed below. Configurations that agree on all of
# them are treated as duplicates and only one of them is benchmarked.
_LINEAR_LAYER_PARAMS = (
    "n_embd",
    "padded_vocab_size",
    "n_head",
    "n_query_groups",
    "head_size",
    "bias",
    "intermediate_size",
    "n_expert",
)

unique_config_names = {}
for config_name in CONFIG_NAMES:
    config = Config.from_name(config_name)
    key = tuple(getattr(config, param) for param in _LINEAR_LAYER_PARAMS)
    # A later config with the same key replaces the earlier one
    unique_config_names[key] = config.name

CONFIG_NAMES = sorted(unique_config_names.values())

# The Mixtral MoE configs are skipped because the current implementation does not support them
# See https://github.com/Lightning-AI/lightning-thunder/issues/124
CONFIG_NAMES = [candidate for candidate in CONFIG_NAMES if "mixtral" not in candidate.lower()]


def make_torch_traces_for_config(name: str):
    """Build one PyTorch-level trace per non-SDPA region of a two-layer LitGPT model.

    The model for config ``name`` is instantiated on CUDA in bfloat16 with
    ``n_layer`` forced to 2, traced with ``thunder_general_jit`` (without
    executing it), and the computation trace is split at every
    ``scaled_dot_product_attention`` call. Each piece between the attention
    calls is re-traced into torch operations.

    Args:
        name: Name of the LitGPT config, as accepted by ``Config.from_name``.

    Returns:
        A list with one trace per region, in the order the regions appear in
        the computation trace. Calling ``.python_callable()`` on an element
        yields a regular Python function.

    NOTE: the language context, compile data/stats and cache-info context set
    up here are not reset before returning.
    """
    config = Config.from_name(name)
    # Two layers is enough to expose the fusing opportunity of the following network boundaries:
    # - Embedding layer -> 0th Transformer layer
    # - Last Transformer layer -> Output layer
    # - End of the Transformer layer -> Beginning of the Transformer layer
    config.n_layer = 2

    model = GPT(config).to(dtype=torch.bfloat16, device="cuda")
    input_shape = (BATCH_SIZE, config.block_size)
    x = torch.randint(0, config.vocab_size, input_shape, dtype=torch.int64, device="cuda")

    # Acquire the initial trace
    # We could use thunder.jit here, but we want to not execute the compiled function
    # and instead only get the initial trace before any transformations
    # jit_model = thunder.jit(model)
    # out = jit_model(x)
    # trace = thunder.last_traces(jit_model)[0]

    # We need to set up contexts that are usually set up by the thunder.jit decorator
    cd = CompileData(fn=model, disable_preprocessing=True, executor_lookasides={})
    cs = CompileStats()
    set_langctx(thunder.torch.torchctx)
    set_compile_data_and_stats(cd, cs)
    thunder._cache_info_ctx.set({})
    jit_results = thunder_general_jit(model, (x,), {}, sharp_edges=thunder.core.options.SHARP_EDGES_OPTIONS.ALLOW)
    # Only the computation trace is needed; the prologue and epilogue traces are not used
    trace = jit_results.computation_trace

    # Remove subsymbols for readability of the trace
    for bsym in trace.bound_symbols:
        bsym.subsymbols = []

    producers, consumers = thunder.core.utils.producers_and_consumers(trace)

    # Remove unpacking prims so that they can be identified as inputs of the first chunk
    trace.bound_symbols = [bsym for bsym in trace.bound_symbols if bsym.sym.id != thunder.prims.unpack_trivial.id]

    # Remove return prim so that it can be identified as the output of the last chunk.
    # The pop happens outside of the assert so that the trace is still modified
    # when assertions are disabled (python -O).
    last_bsym = trace.bound_symbols.pop()
    assert last_bsym.sym.id == thunder.prims.python_return.id

    # This is going to be our delimiter for splitting the trace into chunks
    thunder_sdpa = thunder.torch.scaled_dot_product_attention

    # We want to split the trace into chunks of network between the scaled dot-product attention calls
    num_sdpa_calls = sum(1 for bsym in trace.bound_symbols if bsym.sym.id == thunder_sdpa.id)
    assert num_sdpa_calls == config.n_layer

    # groupby is deliberately applied to the unsorted sequence: consecutive
    # non-SDPA symbols form one chunk, and the SDPA groups are dropped
    chunks = [
        list(group)
        for is_sdpa, group in groupby(trace.bound_symbols, key=lambda bsym: bsym.sym.id == thunder_sdpa.id)
        if not is_sdpa
    ]

    # Now we need to convert the chunks into a list of functions
    regions = [thunder.executors.utils.Region(producers, consumers, chunk) for chunk in chunks]

    # After this point, we will have a list of regions that are represented as regular PyTorch functions
    # We can acquire the Python functions by calling .python_callable() on each "torch_trace" object
    torch_traces = []
    for r in regions:
        # Here we construct a trace that will be used to compile the function
        region_trace = TraceCtx(None)
        region_trace.bound_symbols = list(r.bound_symbols)
        # Inputs and outputs are ordered by proxy name
        sorted_proxy_inputs = [v.proxy for v in sorted(r.inputs, key=lambda v: v.proxy.name)]
        sorted_proxy_outputs = [v.proxy for v in sorted(r.outputs, key=lambda v: v.proxy.name)]
        region_trace.args = sorted_proxy_inputs
        region_trace.kwargs = {}
        region_trace.bound_symbols.append(thunder.prims.python_return.bind(sorted_proxy_outputs, output=()))
        region_trace = thunder.executors.passes.dce(region_trace)

        # The closure is traced right away within this iteration, so it always
        # sees the region_trace of the current region
        def torch_interpreted_func(*args):
            return eval_trace(region_trace, *args, symbol_mapper=to_torch_translator)

        torch_trace = thunder.trace(inline_trace=False)(torch_interpreted_func, *sorted_proxy_inputs)

        # Remove subsymbols for readability of the trace
        for bsym in torch_trace.bound_symbols:
            bsym.subsymbols = []

        torch_traces.append(torch_trace)

    return torch_traces


def wrap_for_benchmark(fn):
    """Return a version of ``fn`` that synchronizes CUDA before and after the call.

    The wrapper forwards all arguments to ``fn`` and returns its result
    unchanged; the metadata of ``fn`` is carried over with ``functools.wraps``.
    """

    @wraps(fn)
    def synchronized(*args, **kwargs):
        torch.cuda.synchronize()
        outcome = fn(*args, **kwargs)
        torch.cuda.synchronize()
        return outcome

    return synchronized


def backward_only(torch_trace: TraceCtx, jit_fn: Callable, fw_setup_fn: Callable):
    """Run the forward pass once and return a callable that only runs backward.

    Args:
        torch_trace: Trace whose Python callable (built without the
            ``no_grad`` decorator) is compiled and executed.
        jit_fn: Executor applied to the trace's Python callable.
        fw_setup_fn: Zero-argument function returning ``(args, kwargs)`` for
            the forward call.

    Returns:
        A function that takes the output gradients as positional arguments
        and calls ``torch.autograd.backward`` on the saved forward results.
        The graph is retained so the function can be called repeatedly.
    """
    fn = torch_trace.python_callable(include_no_grad=False)
    jfn = jit_fn(fn)
    fw_args, fw_kwargs = fw_setup_fn()
    result = jfn(*fw_args, **fw_kwargs)
    result = thunder.core.utils.sequencify(result)

    # Forward inputs that may receive gradients from the backward pass
    fw_tensors = [a for a in (*fw_args, *fw_kwargs.values()) if isinstance(a, torch.Tensor)]

    def backward_fn(*args, **kwargs):
        # Clear the gradients of the forward inputs (not of the incoming output
        # gradients in ``args``) so that every call does the same work instead
        # of accumulating into gradients left over from the previous call
        for t in fw_tensors:
            t.grad = None

        torch.autograd.backward(result, args, retain_graph=True)

    return backward_fn


# Maps every benchmarked executor to the short label used in the test IDs
executor_names = {
    torch_executor: "eager",
    thunder_executor: "thunder",
    # Inductor with the default GEMM backend (ATen)
    torch_compile_executor: "inductor",
    # Inductor with GEMM autotuning enabled and Triton as the GEMM backend
    inductor_triton_executor: "inductor+triton_gemm",
    # The CUTLASS GEMM backend is disabled for now because it fails with the Inductor executor:
    # torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
    # LoweringException: ErrorFromChoice: Error in function: cuda_cutlass_gemm
    # inductor_cutlass_executor: "inductor+cutlass_gemm",
}


@dataclass
class TraceInfo:
    """One benchmark case: a region trace together with where it came from."""

    # Name of the LitGPT config the trace was generated from
    config_name: str
    # Position of the region within the list of traces of that config
    region_idx: int
    # The trace to benchmark
    trace: TraceCtx

    def __repr__(self) -> str:
        # The trace itself is left out to keep the representation short
        return f"TraceInfo(config_name={self.config_name}, region_idx={self.region_idx})"


# litgpt_traces = [
# TraceInfo(name, i, trace) for name in CONFIG_NAMES for i, trace in enumerate(make_torch_traces_for_config(name))
# ]
Comment on lines +206 to +208
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this still needed?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

List comprehensions are easier to read for me than the for-loop below.
I'll remove this of course.


# Collect the benchmark cases for every config, reporting progress along the way
litgpt_traces = []
for j, name in enumerate(CONFIG_NAMES):
    print(f"Constructing benchmark cases for config: {name} ({j + 1}/{len(CONFIG_NAMES)})")
    traces = make_torch_traces_for_config(name)
    litgpt_traces.extend(TraceInfo(name, region_idx, region_trace) for region_idx, region_trace in enumerate(traces))

# Every trace is benchmarked with every executor (iterating the dict yields its keys)
trace_executor_pairs = list(product(litgpt_traces, executor_names))


@pytest.mark.parametrize(
    "info, executor",
    trace_executor_pairs,
    ids=[
        f"{info.config_name}_region{info.region_idx}_{executor_names[executor]}"
        for info, executor in trace_executor_pairs
    ],
)
@pytest.mark.benchmark(group="forward")
def test_litgpt_forward(benchmark, info, executor):
    """Benchmark the forward pass of one region trace with one executor."""
    torch_trace = info.trace

    def make_input(proxy):
        # Floating-point inputs require gradients; other inputs are generated
        # with a lower bound of 0
        is_float = isinstance(proxy.dtype, thunder.core.dtypes.floating)
        return make_tensor(
            proxy.shape,
            dtype=thunder.torch.to_torch_dtype(proxy.dtype),
            device=thunder.core.devices.to_torch_device(proxy.device),
            requires_grad=is_float,
            low=None if is_float else 0,
        )

    def setup():
        # Fresh inputs are generated for every benchmark round
        return [make_input(a) for a in torch_trace.args], {}

    fn = wrap_for_benchmark(executor(torch_trace.python_callable(include_no_grad=False)))

    benchmark.pedantic(fn, setup=setup, rounds=20, warmup_rounds=1)


@pytest.mark.parametrize(
    "info, executor",
    trace_executor_pairs,
    ids=[
        f"{info.config_name}_region{info.region_idx}_{executor_names[executor]}"
        for info, executor in trace_executor_pairs
    ],
)
@pytest.mark.benchmark(group="backward")
def test_litgpt_backward(benchmark, info, executor):
    """Benchmark the backward pass of one region trace with one executor."""
    torch_trace = info.trace

    def make_tensors(proxies, *, differentiable):
        # Build one random tensor per proxy. Non-floating-point tensors are
        # generated with a lower bound of 0; only floating-point tensors can
        # require gradients, and only when ``differentiable`` is set.
        tensors = []
        for proxy in proxies:
            is_float = isinstance(proxy.dtype, thunder.core.dtypes.floating)
            tensors.append(
                make_tensor(
                    proxy.shape,
                    dtype=thunder.torch.to_torch_dtype(proxy.dtype),
                    device=thunder.core.devices.to_torch_device(proxy.device),
                    requires_grad=is_float if differentiable else False,
                    low=None if is_float else 0,
                )
            )
        return tensors

    def fw_setup():
        # Inputs of the forward pass
        return make_tensors(torch_trace.args, differentiable=True), {}

    def bw_setup():
        # Output gradients passed to the backward pass
        return make_tensors(torch_trace.output, differentiable=False), {}

    fn = wrap_for_benchmark(backward_only(torch_trace, executor, fw_setup))

    benchmark.pedantic(fn, setup=bw_setup, rounds=20, warmup_rounds=1)
11 changes: 6 additions & 5 deletions thunder/core/trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ def python_ctx(self) -> dict:
# TODO issue "Add type annotations to Python function produced by traces"
# Consider extending the signature with type information; in particular,
# the type information of the return value might be interesting
def python(self, *, print_depth: int = 1) -> str:
def python(self, *, print_depth: int = 1, include_no_grad=True) -> str:
token = set_tracectx(self)

try:
Expand Down Expand Up @@ -364,7 +364,8 @@ def keyfn(class_or_module: type | ModuleType) -> str:
program.append("")

# Prints the signature and the no_grad context (for when calling torch operations)
program.append("@torch.no_grad()")
if include_no_grad:
program.append("@torch.no_grad()")
# Prints the signature and the no_autocast context
program.append("@no_autocast()")
program.append(signature_str)
Expand Down Expand Up @@ -397,22 +398,22 @@ def keyfn(class_or_module: type | ModuleType) -> str:
# Returns a Python callable that executes the trace
# TODO issue "Create a mechanism for freezing TraceCtx objects"
# Create a mechanism for freezing traces and cache the compilation
def python_callable(self, *, global_dicts: None | dict = None) -> Callable:
def python_callable(self, *, global_dicts: None | dict = None, include_no_grad=True) -> Callable:
python_str: str

# Writes the program to allow it to be edited before execution
path: None | str = _get_execution_file()
if path is not None:
f = open(path, "w")
f.write(self.python())
f.write(self.python(include_no_grad=include_no_grad))
f.close()

input(f"Trace written to {os.path.realpath(path)} Press Any key to execute it")

with open(path) as file:
python_str = file.read()
else:
python_str = self.python()
python_str = self.python(include_no_grad=include_no_grad)

ctx = self.python_ctx()
if global_dicts is not None:
Expand Down
2 changes: 1 addition & 1 deletion thunder/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import opt_einsum

# Initializes the language context
from thunder.torch.langctx import register_method, register_property
from thunder.torch.langctx import register_method, register_property, torchctx

import thunder.clang as clang
import thunder.core.devices as devices
Expand Down
Loading