diff --git a/thunder/benchmarks/__init__.py b/thunder/benchmarks/__init__.py
index 57ced4ef6e..98e3782d50 100644
--- a/thunder/benchmarks/__init__.py
+++ b/thunder/benchmarks/__init__.py
@@ -707,6 +707,32 @@ def torch_compile_executor(fn: Callable) -> Callable:
     return torch.compile(fn)
 
 
+def inductor_gemm_executor(fn: Callable, gemm_backend: str) -> Callable:
+    import torch._inductor.config
+
+    torch.backends.cuda.matmul.allow_tf32 = True
+    torch._dynamo.reset()
+
+    compiled = torch.compile(fn)
+
+    def wrapper(*args, **kwargs):
+        old = torch._inductor.config.max_autotune_gemm_backends
+        old1 = torch._inductor.config.max_autotune_gemm
+        try:
+            torch._inductor.config.max_autotune_gemm_backends = gemm_backend
+            torch._inductor.config.max_autotune_gemm = True
+            return compiled(*args, **kwargs)
+        finally:
+            torch._inductor.config.max_autotune_gemm_backends = old
+            torch._inductor.config.max_autotune_gemm = old1
+
+    return wrapper
+
+
+inductor_cutlass_executor = partial(inductor_gemm_executor, gemm_backend="CUTLASS")
+inductor_triton_executor = partial(inductor_gemm_executor, gemm_backend="TRITON")
+
+
 def thunder_torch_executor(fn: Callable) -> Callable:
     torch.backends.cuda.matmul.allow_tf32 = True
     return thunder.jit(fn, executors=[thunder.pytorch_executor])
diff --git a/thunder/benchmarks/litgpt_chunks.py b/thunder/benchmarks/litgpt_chunks.py
new file mode 100644
index 0000000000..f20fb2bc67
--- /dev/null
+++ b/thunder/benchmarks/litgpt_chunks.py
@@ -0,0 +1,286 @@
+from collections.abc import Callable, Sequence
+from dataclasses import dataclass
+from functools import wraps
+from itertools import groupby, product
+
+import pytest
+
+import torch
+
+from litgpt import Config, GPT
+from litgpt.config import configs
+
+import thunder
+
+from thunder.benchmarks import (
+    inductor_cutlass_executor,
+    inductor_triton_executor,
+    thunder_executor,
+    torch_compile_executor,
+    torch_executor,
+)
+from thunder.common import CompileData, CompileStats
+from thunder.core.compile_data import set_compile_data_and_stats
+
+from thunder.core.jit_ext import thunder_general_jit
+from thunder.core.langctxs import set_langctx
+from thunder.core.trace import TraceCtx
+from thunder.core.transforms import eval_trace
+from thunder.executors.torch_compile import to_torch_translator
+
+from thunder.tests.make_tensor import make_tensor
+
+BATCH_SIZE = 2
+CONFIG_NAMES = list(sorted(c["name"] for c in configs))
+# CONFIG_NAMES = ["Llama-2-7b-hf",]
+
+# There are many configurations but only the following parameters affect the Linear layers in the model:
+# - n_embd
+# - padded_vocab_size
+# - n_head
+# - n_query_groups
+# - head_size
+# - bias
+# - intermediate_size
+# - n_expert
+# Let's select only the configurations that differ in these parameters
+unique_config_names = {}
+for config_name in CONFIG_NAMES:
+    config = Config.from_name(config_name)
+    key = tuple(
+        getattr(config, k)
+        for k in (
+            "n_embd",
+            "padded_vocab_size",
+            "n_head",
+            "n_query_groups",
+            "head_size",
+            "bias",
+            "intermediate_size",
+            "n_expert",
+        )
+    )
+    unique_config_names[key] = config.name
+
+CONFIG_NAMES = list(sorted(unique_config_names.values()))
+
+# We will skip the Mixtral MoE configs because they are not supported by the current implementation
+# See https://github.com/Lightning-AI/lightning-thunder/issues/124
+CONFIG_NAMES = [name for name in CONFIG_NAMES if "mixtral" not in name.lower()]
+
+
+def make_torch_traces_for_config(name: str):
+    config = Config.from_name(name)
+    # Two layers is enough to expose the fusing opportunity of the following network boundaries:
+    # - Embedding layer -> 0th Transformer layer
+    # - Last Transformer layer -> Output layer
+    # - End of the Transformer layer -> Beginning of the Transformer layer
+    config.n_layer = 2
+
+    model = GPT(config).to(dtype=torch.bfloat16, device="cuda")
+    input_shape = (BATCH_SIZE, config.block_size)
+    x = torch.randint(0, config.vocab_size, input_shape, dtype=torch.int64, device="cuda")
+
+    # Acquire the initial trace
+    # We could use thunder.jit here, but we want to not execute the compiled function
+    # and instead only get the initial trace before any transformations
+    # jit_model = thunder.jit(model)
+    # out = jit_model(x)
+    # trace = thunder.last_traces(jit_model)[0]
+
+    # We need to set up contexts that are usually set up by the thunder.jit decorator
+    cd = CompileData(fn=model, disable_preprocessing=True, executor_lookasides={})
+    cs = CompileStats()
+    set_langctx(thunder.torch.torchctx)
+    set_compile_data_and_stats(cd, cs)
+    thunder._cache_info_ctx.set({})
+    jit_results = thunder_general_jit(model, (x,), {}, sharp_edges=thunder.core.options.SHARP_EDGES_OPTIONS.ALLOW)
+    prologue = jit_results.prologue_trace
+    trace = jit_results.computation_trace
+    epilogue = jit_results.epilogue_trace
+
+    # Remove subsymbols for readability of the trace
+    for bsym in trace.bound_symbols:
+        bsym.subsymbols = []
+
+    producers, consumers = thunder.core.utils.producers_and_consumers(trace)
+
+    # Remove unpacking prims so that they can be identified as inputs of the first chunk
+    trace.bound_symbols = [bsym for bsym in trace.bound_symbols if bsym.sym.id != thunder.prims.unpack_trivial.id]
+
+    # Remove return prim so that it can be identified as the output of the last chunk
+    assert trace.bound_symbols.pop().sym.id == thunder.prims.python_return.id
+
+    # We want to split the trace into chunks of network between the scaled dot-product attention calls
+    assert (
+        len([bsym for bsym in trace.bound_symbols if bsym.sym.id == thunder.torch.scaled_dot_product_attention.id])
+        == config.n_layer
+    )
+
+    # This is going to be our delimiter for splitting the trace into chunks
+    thunder_sdpa = thunder.torch.scaled_dot_product_attention
+    chunks = list(list(g) for k, g in groupby(trace.bound_symbols, key=lambda x: x.sym.id == thunder_sdpa.id) if not k)
+
+    # Now we need to convert the chunks into a list of functions
+    regions = [thunder.executors.utils.Region(producers, consumers, chunk) for chunk in chunks]
+
+    # After this point, we will have a list of regions that are represented as regular PyTorch functions
+    # We can acquire the Python functions by calling .python_callable() on each "torch_trace" object
+    torch_traces = []
+    for r in regions:
+        # Here we construct a trace that will be used to compile the function
+        region_trace = TraceCtx(None)
+        region_trace.bound_symbols = list(r.bound_symbols)
+        sorted_proxy_inputs = [v.proxy for v in sorted(r.inputs, key=lambda x: x.proxy.name)]
+        sorted_proxy_outputs = [v.proxy for v in sorted(r.outputs, key=lambda x: x.proxy.name)]
+        region_trace.args = sorted_proxy_inputs
+        region_trace.kwargs = {}
+        region_trace.bound_symbols.append(thunder.prims.python_return.bind(sorted_proxy_outputs, output=()))
+        region_trace = thunder.executors.passes.dce(region_trace)
+
+        def torch_interpreted_func(*args):
+            return eval_trace(region_trace, *args, symbol_mapper=to_torch_translator)
+
+        torch_trace = thunder.trace(inline_trace=False)(torch_interpreted_func, *sorted_proxy_inputs)
+
+        # Remove subsymbols for readability of the trace
+        for bsym in torch_trace.bound_symbols:
+            bsym.subsymbols = []
+
+        torch_traces.append(torch_trace)
+
+    return torch_traces
+
+
+def wrap_for_benchmark(fn):
+    @wraps(fn)
+    def fn_(*args, **kwargs):
+        torch.cuda.synchronize()
+        result = fn(*args, **kwargs)
+        torch.cuda.synchronize()
+        return result
+
+    return fn_
+
+
+def backward_only(torch_trace: TraceCtx, jit_fn: Callable, fw_setup_fn: Callable):
+    fn = torch_trace.python_callable(include_no_grad=False)
+    jfn = jit_fn(fn)
+    args, kwargs = fw_setup_fn()
+    result = jfn(*args, **kwargs)
+    result = thunder.core.utils.sequencify(result)
+
+    def backward_fn(*args, **kwargs):
+        for a in args:
+            a.grad = None
+
+        torch.autograd.backward(result, args, retain_graph=True)
+
+    return backward_fn
+
+
+executor_names = {
+    torch_executor: "eager",
+    thunder_executor: "thunder",
+    # Inductor with the default GEMM backend (Aten)
+    torch_compile_executor: "inductor",
+    # Inductor with autotuning between Aten and CUTLASS or Triton as the GEMM backend
+    inductor_triton_executor: "inductor+triton_gemm",
+    # At the moment, I was not able to get the CUTLASS backend to work with the Inductor executor
+    # torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
+    # LoweringException: ErrorFromChoice: Error in function: cuda_cutlass_gemm
+    # inductor_cutlass_executor: "inductor+cutlass_gemm",
+}
+
+
+@dataclass
+class TraceInfo:
+    config_name: str
+    region_idx: int
+    trace: TraceCtx
+
+    def __repr__(self):
+        return f"TraceInfo(config_name={self.config_name}, region_idx={self.region_idx})"
+
+
+# litgpt_traces = [
+#     TraceInfo(name, i, trace) for name in CONFIG_NAMES for i, trace in enumerate(make_torch_traces_for_config(name))
+# ]
+
+# Rewrite list comprehension above to include printing progress bar
+litgpt_traces = []
+for j, name in enumerate(CONFIG_NAMES):
+    print(f"Constructing benchmark cases for config: {name} ({j + 1}/{len(CONFIG_NAMES)})")
+    traces = make_torch_traces_for_config(name)
+    for i, trace in enumerate(traces):
+        litgpt_traces.append(TraceInfo(name, i, trace))
+
+# Now we have a list of torch_traces that are ready to be benchmarked
+trace_executor_pairs = list(product(litgpt_traces, (executor_names.keys())))
+
+
+@pytest.mark.parametrize(
+    "info, executor",
+    trace_executor_pairs,
+    ids=[
+        f"{info.config_name}_region{info.region_idx}_{executor_names[executor]}"
+        for info, executor in trace_executor_pairs
+    ],
+)
+@pytest.mark.benchmark(group="forward")
+def test_litgpt_forward(benchmark, info, executor):
+    torch_trace = info.trace
+
+    def setup():
+        args = []
+        for a in torch_trace.args:
+            torch_dtype = thunder.torch.to_torch_dtype(a.dtype)
+            torch_device = thunder.core.devices.to_torch_device(a.device)
+            is_float = isinstance(a.dtype, thunder.core.dtypes.floating)
+            low = 0 if not is_float else None
+            args.append(make_tensor(a.shape, dtype=torch_dtype, device=torch_device, requires_grad=is_float, low=low))
+        return args, {}
+
+    fn = torch_trace.python_callable(include_no_grad=False)
+    fn = executor(fn)
+    fn = wrap_for_benchmark(fn)
+
+    benchmark.pedantic(fn, setup=setup, rounds=20, warmup_rounds=1)
+
+
+@pytest.mark.parametrize(
+    "info, executor",
+    trace_executor_pairs,
+    ids=[
+        f"{info.config_name}_region{info.region_idx}_{executor_names[executor]}"
+        for info, executor in trace_executor_pairs
+    ],
+)
+@pytest.mark.benchmark(group="backward")
+def test_litgpt_backward(benchmark, info, executor):
+    torch_trace = info.trace
+
+    def fw_setup():
+        args = []
+        for a in torch_trace.args:
+            torch_dtype = thunder.torch.to_torch_dtype(a.dtype)
+            torch_device = thunder.core.devices.to_torch_device(a.device)
+            is_float = isinstance(a.dtype, thunder.core.dtypes.floating)
+            low = 0 if not is_float else None
+            args.append(make_tensor(a.shape, dtype=torch_dtype, device=torch_device, requires_grad=is_float, low=low))
+        return args, {}
+
+    def bw_setup():
+        args = []
+        for a in torch_trace.output:
+            torch_dtype = thunder.torch.to_torch_dtype(a.dtype)
+            torch_device = thunder.core.devices.to_torch_device(a.device)
+            is_float = isinstance(a.dtype, thunder.core.dtypes.floating)
+            low = 0 if not is_float else None
+            args.append(make_tensor(a.shape, dtype=torch_dtype, device=torch_device, requires_grad=False, low=low))
+        return args, {}
+
+    fn = backward_only(torch_trace, executor, fw_setup)
+    fn = wrap_for_benchmark(fn)
+
+    benchmark.pedantic(fn, setup=bw_setup, rounds=20, warmup_rounds=1)
diff --git a/thunder/core/trace.py b/thunder/core/trace.py
index dc2029e79c..219f5af0a3 100644
--- a/thunder/core/trace.py
+++ b/thunder/core/trace.py
@@ -306,7 +306,7 @@ def python_ctx(self) -> dict:
     # TODO issue "Add type annotations to Python function produced by traces"
     # Consider extending the signature with type information, in particular the
     # the type information of the return value might be interesting
-    def python(self, *, print_depth: int = 1) -> str:
+    def python(self, *, print_depth: int = 1, include_no_grad=True) -> str:
         token = set_tracectx(self)
 
         try:
@@ -364,7 +364,8 @@ def keyfn(class_or_module: type | ModuleType) -> str:
             program.append("")
 
             # Prints the signature and the no_grad context (for when calling torch operations)
-            program.append("@torch.no_grad()")
+            if include_no_grad:
+                program.append("@torch.no_grad()")
             # Prints the signature and the no_autocast context
             program.append("@no_autocast()")
             program.append(signature_str)
@@ -397,14 +398,14 @@ def keyfn(class_or_module: type | ModuleType) -> str:
     # Returns a Python callable that executes the trace
     # TODO issue "Create a mechanism for freezing TraceCtx objects"
    # Create a mechanism for freezing traces and cache the compilation
-    def python_callable(self, *, global_dicts: None | dict = None) -> Callable:
+    def python_callable(self, *, global_dicts: None | dict = None, include_no_grad=True) -> Callable:
         python_str: str
 
         # Writes the program to allow it to be edited before execution
         path: None | str = _get_execution_file()
         if path is not None:
             f = open(path, "w")
-            f.write(self.python())
+            f.write(self.python(include_no_grad=include_no_grad))
             f.close()
 
             input(f"Trace written to {os.path.realpath(path)} Press Any key to execute it")
@@ -412,7 +413,7 @@ def python_callable(self, *, global_dicts: None | dict = None) -> Callable:
             with open(path) as file:
                 python_str = file.read()
         else:
-            python_str = self.python()
+            python_str = self.python(include_no_grad=include_no_grad)
 
         ctx = self.python_ctx()
         if global_dicts is not None:
diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py
index 6d4d1cd049..80cf209aab 100644
--- a/thunder/torch/__init__.py
+++ b/thunder/torch/__init__.py
@@ -13,7 +13,7 @@
 import opt_einsum
 
 # Initializes the language context
-from thunder.torch.langctx import register_method, register_property
+from thunder.torch.langctx import register_method, register_property, torchctx
 
 import thunder.clang as clang
 import thunder.core.devices as devices
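
Reviewer note, not part of the patch: a minimal sketch of how the inductor_triton_executor wrapper added to thunder/benchmarks/__init__.py is meant to be used, assuming a CUDA machine with a recent PyTorch; the matmul function and tensor shapes below are illustrative only.

    import torch
    from thunder.benchmarks import inductor_triton_executor

    def matmul(a, b):
        return a @ b

    # Compiles the function with torch.compile; around each call the wrapper
    # temporarily sets max_autotune_gemm_backends="TRITON" and
    # max_autotune_gemm=True in torch._inductor.config, then restores the
    # previous values.
    fn = inductor_triton_executor(matmul)

    a = torch.randn(1024, 1024, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(1024, 1024, device="cuda", dtype=torch.bfloat16)
    out = fn(a, b)

The per-region benchmarks themselves are collected by pytest from thunder/benchmarks/litgpt_chunks.py (benchmark groups "forward" and "backward"); the new python_callable(include_no_grad=False) option is what allows the backward benchmarks to run autograd through the generated trace code.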