From a1bbce61c1a56881dd55ab2d73c5a1876a28bd3f Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Wed, 3 Jul 2024 14:57:47 +0900 Subject: [PATCH 01/10] fix some typos (#696) --- thunder/core/interpreter.py | 37 ++++++++++++++++++------------------- 1 file changed, 18 insertions(+), 19 deletions(-) diff --git a/thunder/core/interpreter.py b/thunder/core/interpreter.py index 6570c2f302..8c8b117c8b 100644 --- a/thunder/core/interpreter.py +++ b/thunder/core/interpreter.py @@ -16,7 +16,7 @@ import traceback import weakref import torch -from typing import Any, Literal, NamedTuple, TypedDict +from typing import Any, Literal, TypedDict from collections.abc import Callable, Iterable, Iterator, Mapping, MutableMapping, Sequence, Set, Sized import collections import operator @@ -29,7 +29,6 @@ CodeType, CoroutineType, FrameType, - GetSetDescriptorType, FunctionType, MethodType, MethodDescriptorType, @@ -180,7 +179,7 @@ def track_items(self): self.has_item_tracking = True def register_proxy(self, proxy): - # note: the proxy is responsible for capturiing all the existing attributes/values + # note: the proxy is responsible for capturing all the existing attributes/values assert ( self.original_value is self.nothing ), "cannot proxy multiple times, please file an issue to discuss your use-case" @@ -272,7 +271,7 @@ def wrap(value: Any, /, *, provenance: ProvenanceRecord) -> WrappedValue: if cached is not None: potential_wrap = cached[0]() if potential_wrap is not None: - # Note: we want to cache mutuable objects to not run into trouble + # Note: we want to cache mutable objects to not run into trouble # with multiple accesses to the same. # As the cache only holds a weakref to the WrappedValue instance # one must persist WrappedValues once things are starting to be modified. @@ -368,7 +367,7 @@ def populate_item_wrappers(l): l.key_wrappers[k] = wk # or have those from an iteration of the input? return - raise NotImplementedError(f"populate item wrapppers for {type(l.value)}") + raise NotImplementedError(f"populate item wrappers for {type(l.value)}") # @@ -495,7 +494,7 @@ class ReturnLogItem(TypedDict): # To the interpreter User Exceptions are part of its normal operations. So we usually # don't use Python's exception mechanism to handle these. Instead the # interpreter mimics CPython's implementation of exception handling by -# defining analoguous structures (curexec and exception_stack) set by do_raise and returns +# defining analogous structures (curexec and exception_stack) set by do_raise and returns # INTERPRETER_SIGNALS.EXCEPTION_RAISED when running things that raised exceptions. # Here in particular: # - Handlers are "inside the interpreter", @@ -524,7 +523,7 @@ def __init__(self, record_history: bool, debug_log: None | StringIO): self.exception_stack = [None] # ts.exc_info is the top of the stack (self.exception_stack[-1]). # Note that most of the time, we are changing the self.exception_stack[-1] instead of popping/pushing exceptions. - # `exception_stack[-1] = ...` is the equivalent of assiging ts.exc_info->exc_type/value/traceback. + # `exception_stack[-1] = ...` is the equivalent of assigning ts.exc_info->exc_type/value/traceback. 
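The comments above describe how the interpreter mirrors CPython's thread-state exception bookkeeping instead of relying on Python's own raise/except machinery. Below is a minimal, self-contained sketch of that bookkeeping; it is illustrative only, with names and structure simplified from thunder/core/interpreter.py.

class MiniRuntimeCtx:
    """Toy model of the interpreter's exception bookkeeping (not Thunder's actual code)."""

    def __init__(self):
        # exception_stack[-1] plays the role of ts.exc_info; handlers usually
        # overwrite the top entry rather than pushing/popping.
        self.exception_stack = [None]
        # curexc corresponds to ts.curexc(_type/_value/_traceback) in 3.10: the
        # exception currently propagating, before any handler has caught it.
        self.curexc = None

    def do_raise(self, exc: BaseException) -> str:
        # "raising" records the exception and returns a sentinel that the dispatch
        # loop checks for, mirroring INTERPRETER_SIGNALS.EXCEPTION_RAISED
        self.curexc = exc
        return "EXCEPTION_RAISED"

    def enter_handler(self) -> BaseException:
        # an exception handler moves the in-flight exception into exc_info
        exc, self.curexc = self.curexc, None
        self.exception_stack[-1] = exc
        return exc

ctx = MiniRuntimeCtx()
assert ctx.do_raise(ValueError("boom")) == "EXCEPTION_RAISED"
assert isinstance(ctx.enter_handler(), ValueError)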
# ts.exc_state is exc_info (the bottom-most element of the stack(?)) # ts.curexc (in Python 3.10 as _type / _value / _traceback) is the exception currently being raised @@ -613,7 +612,7 @@ def get_current_user_source_location(self) -> tuple[str, Positions]: return frame.code.co_filename, frame.positions return None, None - # TODO Instead of appending to both the log and and interpreted_instructions we could + # TODO Instead of appending to both the log and interpreted_instructions we could # consider just appending to the log and then filtering to only instructions when # interpreted_instructions is accessed def record_interpreted_instruction(self, inst: dis.Instruction, /) -> InterpreterRuntimeCtx: @@ -1297,7 +1296,7 @@ def wrapping_wrapper(*args, **kwargs): return wrapping_wrapper -# Calling a function as an opaque function makes the interpeter not trace into it +# Calling a function as an opaque function makes the interpreter not trace into it @interpreter_needs_wrap def call_opaque(fn, /, *args, **kwargs): return fn(*args, **kwargs) @@ -2003,7 +2002,7 @@ def __next__(self): return res -# wrapper-handling lookasides for sequences and mutuable sequences. +# wrapper-handling lookasides for sequences and mutable sequences. # note: # - these are only for use when wrapping is enabled # - the methods (or the corresponding functions) will be registered @@ -2011,7 +2010,7 @@ def __next__(self): # called with wrapped values also for self and self.value will point # to the actual object... # -# TODO: maybe make these generic for sequences / mutuable sequence +# TODO: maybe make these generic for sequences / mutable sequence # https://docs.python.org/3/library/stdtypes.html#common-sequence-operations class SequenceWrapperMethods(WrappedValue): # NOTE! This is not actually a WrappedValue. However, @@ -2700,8 +2699,8 @@ def create_namedtuple(typename: str, field_names: str, **kwargs): } -# While mutuable sequences (lists) are created empty in __new__ and populated in __init__, -# immutuable sequences (tuples) are created with contents in __new__ and __init__ is a nop +# While mutable sequences (lists) are created empty in __new__ and populated in __init__, +# immutable sequences (tuples) are created with contents in __new__ and __init__ is a nop # (object.__init__, actually). def _tuple_new_provenance_tracking_lookaside(cls, iterable=(), /): new_tuple_type = cls.value @@ -3950,7 +3949,7 @@ def _end_async_for_handler_3_10( assert len(stack) >= try_block.level + 3 del stack[try_block.level + 3 :] - exc_type = frame.interpreter_stack.pop() # we ignore that and asume == type(exc_value) + exc_type = frame.interpreter_stack.pop() # we ignore that and assume == type(exc_value) exc_value = frame.interpreter_stack.pop() exc_traceback = frame.interpreter_stack.pop() if exc_value != None: @@ -4205,11 +4204,11 @@ def _import_name_handler( fromlist = stack.pop() level = stack.pop() - # relative imports rely on the the current module's name (from the frame stac?) + # relative imports rely on the current module's name (from the frame stac?) # but that isn't available if we use impl, so we resolve it here. if level > 0: # relative import # cannot do this in impl easily, but error handling? 
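For readers unfamiliar with the level/__package__ dance the handler needs, the snippet below is a rough, standalone approximation of what CPython's resolve_name (referenced in the TODO just below) computes for a relative import. It is illustrative only, not the handler's implementation.

def resolve_relative_name(name: str, globals: dict, level: int) -> str:
    # anchor package: __package__ if set, otherwise derived from __name__
    package = globals.get("__package__")
    if package is None:
        package = globals["__name__"]
        if "__path__" not in globals:  # a plain module rather than a package
            package = package.rpartition(".")[0]
    # each extra level strips one trailing component ("from .. import x" -> level == 2)
    for _ in range(level - 1):
        package = package.rpartition(".")[0]
    return f"{package}.{name}" if name else package

# e.g. inside module "pkg.sub.mod", "from . import utils" resolves against "pkg.sub":
assert resolve_relative_name("utils", {"__name__": "pkg.sub.mod"}, level=1) == "pkg.sub.utils"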
- # TODO: model this more after resove_name in CPython's Python/import.c + # TODO: model this more after resolve_name in CPython's Python/import.c def get_current_name(globals): package = globals.get("__package__") if package is None: @@ -5218,7 +5217,7 @@ def impl(tos): def _push_exc_info_handler(inst: dis.Instruction, /, stack: InterpreterStack, exception_stack: list, **kwargs) -> None: assert exception_stack top = stack.pop() - # CPython reads exc_info->exc_type/value/traceback, see RuntimeCtx inititalization of exception_stack for more info + # CPython reads exc_info->exc_type/value/traceback, see RuntimeCtx initialization of exception_stack for more info stack.append(exception_stack[-1]) stack.append(top) assert isinstance(top, BaseException) @@ -5869,7 +5868,7 @@ def _yield_value_handler(inst: dis.Instruction, /, stack: InterpreterStack, **kw # (generator, async generator, coroutine), we define generic equivalents here # that take the interpreter frame and execute until the next yield point. # The way these functions work in Python is that objects are created and -# retruned either on invocation (Python <=3.10) or by the RETURN_GENERATOR +# returned either on invocation (Python <=3.10) or by the RETURN_GENERATOR # opcode (Python >= 3.11). def make_generator( frame: InterpreterFrame, @@ -6482,7 +6481,7 @@ def _run_frame( # NormalizeException ? # CPython sets exc_info->exc_type/value/traceback here - # see RuntimeCtx inititalization of exception_stack for more info + # see RuntimeCtx initialization of exception_stack for more info runtimectx.exception_stack[-1] = current_exception with frame.interpreter_stack.set_cur_instruction(PseudoInst.EXCEPTION_HANDLER): frame.interpreter_stack.append( From da23a0b0e9ad17568be8566ad839ca0b0e88043b Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Wed, 3 Jul 2024 16:26:28 +0200 Subject: [PATCH 02/10] use short device str for type_str (#702) --- thunder/core/proxies.py | 2 +- thunder/tests/test_core.py | 18 ++++++++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/thunder/core/proxies.py b/thunder/core/proxies.py index b60f795670..df6dce4540 100644 --- a/thunder/core/proxies.py +++ b/thunder/core/proxies.py @@ -1302,7 +1302,7 @@ def __repr__(self): return f'<{type(self).__name__}(name="{self.name}", dtype={self.dtype}, shape={self.shape})>' def type_string(self): - return f"{self.device} {self.dtype.shortname()}{list(self.shape)}" + return f"{self.device.device_str()} {self.dtype.shortname()}{list(self.shape)}" # NOTE __getattr__ is overridden to support language-specific methods def __getattr__(self, attr: str, /): diff --git a/thunder/tests/test_core.py b/thunder/tests/test_core.py index fcf1b813fa..d63ebe7a14 100644 --- a/thunder/tests/test_core.py +++ b/thunder/tests/test_core.py @@ -2871,3 +2871,21 @@ def test_proxy_repr(): assert p.__repr__() == '' assert t.__repr__() == '' assert c.__repr__() == '' + + +def test_type_string(): + def fn(x): + return 2 * x + + jfn = thunder.jit(fn) + + a = torch.randn(2, 2) + + jfn(a) + + tr = thunder.last_traces(jfn)[0] + + assert tr.bound_symbols[1].sym == ltorch.mul + (pystr,) = tr.bound_symbols[1].python(0) + + assert pystr == 'result = ltorch.mul(2, x) # result: "cpu f32[2, 2]"' From 29c4a21280c0ccbc064e5aedd348ce5a041fd77f Mon Sep 17 00:00:00 2001 From: parthmannan <38387286+parthmannan@users.noreply.github.com> Date: Wed, 3 Jul 2024 07:42:31 -0700 Subject: [PATCH 03/10] Catch exceptions in throughput measurement utility (#697) --- thunder/benchmarks/benchmark_litgpt.py | 69 
+++++++++++++++----------- 1 file changed, 40 insertions(+), 29 deletions(-) diff --git a/thunder/benchmarks/benchmark_litgpt.py b/thunder/benchmarks/benchmark_litgpt.py index b037cef798..bad6ef7424 100644 --- a/thunder/benchmarks/benchmark_litgpt.py +++ b/thunder/benchmarks/benchmark_litgpt.py @@ -363,30 +363,37 @@ def pad_collate(batch): return train_dataloader def calculate_model_flops(self): - meta = torch.device("meta") device = self.device - self.device = meta - - # calculate flops on a meta-device model because we only care about the shapes and - # because the flops calculator installs hooks on the model - meta_model = self.init_model() - - x = torch.randint(0, 1, (self.micro_batch_size, meta_model.config.block_size), device=meta) - model_fwd = lambda: meta_model(x) - model_loss = lambda y: torch.nn.functional.cross_entropy( - y.reshape(-1, y.size(-1)), x.reshape(-1), ignore_index=-1 - ) - self.perf_metrics["model_flops"] = measure_flops(meta_model, model_fwd, model_loss) - - self.device = device + try: + meta = torch.device("meta") + self.device = meta + + # calculate flops on a meta-device model because we only care about the shapes and + # because the flops calculator installs hooks on the model + meta_model = self.init_model() + + x = torch.randint(0, 1, (self.micro_batch_size, meta_model.config.block_size), device=meta) + model_fwd = lambda: meta_model(x) + model_loss = lambda y: torch.nn.functional.cross_entropy( + y.reshape(-1, y.size(-1)), x.reshape(-1), ignore_index=-1 + ) + self.perf_metrics["model_flops"] = measure_flops(meta_model, model_fwd, model_loss) + finally: + self.device = device def train(self): t0 = None if global_rank in [0, None]: - # Calculate the model FLOPs - self.calculate_model_flops() - # Setup throughput Collection - self.throughput = Throughput(window_size=self.max_iters - self.warmup_iter, world_size=world_size) + try: + # Calculate the model FLOPs + self.calculate_model_flops() + # Setup throughput Collection + self.throughput = Throughput(window_size=self.max_iters - self.warmup_iter, world_size=world_size) + except: + self.throughput = None + print( + f"Model Flops/Throughput calculation failed for model {self.model_name}. Skipping throughput metric collection." 
+ ) if self.skip_data_sync: data_sync_ctx = self.model.no_sync @@ -442,22 +449,26 @@ def train(self): f"iter {i}: loss {loss_item:.4f}, iter time: {(t1 - iter_t0) * 1000:.2f}ms, t: {input_ids.size(1)}" ) if i >= self.warmup_iter: - self.throughput.update( - time=(t1 - t0), - flops=self.perf_metrics["model_flops"], - batches=i, - samples=(i * self.micro_batch_size * self.gradient_accumulation_steps), - lengths=(i * self.micro_batch_size * self.gradient_accumulation_steps * self.config.block_size), - ) + if self.throughput: + self.throughput.update( + time=(t1 - t0), + flops=self.perf_metrics["model_flops"], + batches=i, + samples=(i * self.micro_batch_size * self.gradient_accumulation_steps), + lengths=( + i * self.micro_batch_size * self.gradient_accumulation_steps * self.config.block_size + ), + ) if global_rank in [0, None]: # print(f"Total time: {(t1 - t0):.2f}s") self.perf_metrics["average_iter_time"] = ((t1 - t0) * 1000) / (self.max_iters - self.warmup_iter) def add_perf_metrics(self): - metrics = self.throughput.compute() - self.perf_metrics["tokens_per_sec"] = metrics.get("items_per_sec", metrics["device/items_per_sec"]) - self.perf_metrics["model_flop_per_sec"] = metrics.get("flops_per_sec", metrics["device/flops_per_sec"]) + if self.throughput: + metrics = self.throughput.compute() + self.perf_metrics["tokens_per_sec"] = metrics.get("items_per_sec", metrics["device/items_per_sec"]) + self.perf_metrics["model_flop_per_sec"] = metrics.get("flops_per_sec", metrics["device/flops_per_sec"]) self.perf_metrics["memory_used_GB"] = torch.cuda.max_memory_allocated() / 1e9 def add_model_info_to_metrics(self): From 1d7a01d0adaca31cb3c1e85fff1646138ba35954 Mon Sep 17 00:00:00 2001 From: Vedaanta Agarwalla <142048820+vedaanta@users.noreply.github.com> Date: Wed, 3 Jul 2024 09:55:08 -0700 Subject: [PATCH 04/10] reenable some sdpa tests on pytorch dev (#691) Co-authored-by: Thomas Viehmann --- thunder/tests/opinfos.py | 16 ---------------- thunder/tests/test_cudnn_executor.py | 15 +-------------- thunder/tests/test_grad.py | 6 +----- thunder/tests/test_jit_general.py | 2 +- 4 files changed, 3 insertions(+), 36 deletions(-) diff --git a/thunder/tests/opinfos.py b/thunder/tests/opinfos.py index 8841b0d2aa..c4046778fd 100644 --- a/thunder/tests/opinfos.py +++ b/thunder/tests/opinfos.py @@ -7779,13 +7779,6 @@ def scaled_dot_product_attention_error_generator(op, device, **kwargs): "test_vjp_correctness", dtypes=(datatypes.float64,), ), - DecorateInfo( - pytest.mark.skip(reason="https://github.com/pytorch/pytorch/issues/129579"), - "test_cudnn_vs_torch_consistency", - dtypes=(datatypes.bfloat16, datatypes.float16, datatypes.float32), - devicetypes=(devices.DeviceType.CUDA,), - active_if=version_between(torch.__version__, min_ver="2.5.0a0", max_ver="2.5.0a99"), - ), ), ) nn_ops.append(sdpa_opinfo) @@ -7899,15 +7892,6 @@ def grad_scaled_dot_product_attention_sample_generator(op, device, dtype, requir # NOTE: NotImplementedError: Could not run 'aten::_scaled_dot_product_efficient_attention' with arguments from the 'CPU' backend. 
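A side note on the benchmark change above: calculate_model_flops builds the model on the meta device because a FLOP count depends only on tensor shapes, so no real parameters are ever allocated. Here is a standalone sketch of the same trick with stock PyTorch; the tiny Linear stands in for init_model(), and FlopCounterMode is assumed to give a comparable shape-based count to measure_flops.

import torch
from torch.utils.flop_counter import FlopCounterMode

# meta tensors carry shape/dtype but no storage, so even huge models are "free" here
with torch.device("meta"):
    model = torch.nn.Linear(1024, 1024, bias=False)
    x = torch.randn(8, 1024)

with FlopCounterMode(display=False) as counter:
    model(x)

print(counter.get_total_flops())  # 2 * 8 * 1024 * 1024: multiply-adds counted as 2 FLOPs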
# NOTE: NotImplementedError: Could not run 'aten::_scaled_dot_product_efficient_attention_backward' with arguments from the 'CPU' backend devicetypes=(devices.DeviceType.CUDA,), - test_directives=( - DecorateInfo( - pytest.mark.skip(reason="https://github.com/pytorch/pytorch/issues/129579"), - "test_core_vs_torch_consistency", - dtypes=(datatypes.bfloat16, datatypes.float16, datatypes.float32), - devicetypes=(devices.DeviceType.CUDA,), - active_if=version_between(torch.__version__, min_ver="2.5.0a0", max_ver="2.5.0a99"), - ), - ), ) nn_ops.append(grad_sdpa_opinfo) diff --git a/thunder/tests/test_cudnn_executor.py b/thunder/tests/test_cudnn_executor.py index 909b77519a..59ef4f7a69 100644 --- a/thunder/tests/test_cudnn_executor.py +++ b/thunder/tests/test_cudnn_executor.py @@ -13,7 +13,7 @@ from thunder.core.utils import flatten_func from thunder.tests.framework import instantiate, NOTHING, ops, requiresCUDA, run_snippet, TorchExecutor, version_between from thunder.tests.make_tensor import make_tensor, make_tensor_like -from thunder.tests.opinfos import get_opinfo, OpInfo, DecorateInfo +from thunder.tests.opinfos import get_opinfo, OpInfo from thunder.tests.test_grad import _make_differentiable_wrapper cudnn = pytest.importorskip("cudnn") @@ -96,15 +96,6 @@ def grad_scaled_dot_product_attention_reference_generator(op, device, dtype, req thunder.dtypes.bfloat16, ), devicetypes=(devices.DeviceType.CUDA,), - test_directives=( - DecorateInfo( - pytest.mark.skip(reason="https://github.com/pytorch/pytorch/issues/129579"), - "test_cudnn_vs_torch_consistency", - dtypes=(dtypes.bfloat16, dtypes.float16, dtypes.float32), - devicetypes=(devices.DeviceType.CUDA,), - active_if=version_between(torch.__version__, min_ver="2.5.0a0", max_ver="2.5.0a99"), - ), - ), ) @@ -216,10 +207,6 @@ def test_cudnn_vs_torch_consistency(op, device, dtype, *_): LooseVersion(cudnn.backend_version_string()) < LooseVersion("8.9.5"), reason="cuDNN is required to be at least `8.9.5`", ) -@pytest.mark.skipif( - version_between(torch.__version__, min_ver="2.5.0a0", max_ver="2.5.0a99"), - reason="https://github.com/pytorch/pytorch/issues/129579", -) @pytest.mark.parametrize("may_cat_grad_qkv", (True, False), ids=("may-cat-grad-qkv", "never-cat-grad-qkv")) @pytest.mark.parametrize("dtype", grad_sdpa_cudnn_opinfo.dtypes(), ids=tuple(map(str, grad_sdpa_cudnn_opinfo.dtypes()))) def test_vjp_correctness_cudnn_sdpa(dtype, may_cat_grad_qkv): diff --git a/thunder/tests/test_grad.py b/thunder/tests/test_grad.py index f119af1eab..4997d8e6be 100644 --- a/thunder/tests/test_grad.py +++ b/thunder/tests/test_grad.py @@ -541,10 +541,6 @@ def test_vjp_correctness_index_put_manual(op, device, dtype, executor, comp): # NOTE Scaled_Dot_Product_Efficient_Attention_Backward does not support fp64 dtypes # RuntimeError: Only fp32, half & bf16 supported at the moment -@pytest.mark.skipif( - not version_between(torch.__version__, min_ver="2.5.0a0", max_ver="2.5.0a99"), - reason="https://github.com/pytorch/pytorch/issues/129579", -) @ops( (get_opinfo("grad_forward_scaled_dot_product_attention"),), supported_dtypes=(dtypes.float16, dtypes.bfloat16), @@ -553,7 +549,7 @@ def test_vjp_correctness_index_put_manual(op, device, dtype, executor, comp): def test_vjp_correctness_sdpa_manual(op, device, dtype, executor, comp): if version_between(torch.__version__, min_ver="2.5.0a0", max_ver="2.5.0a99"): raise pytest.skip( - "https://github.com/pytorch/pytorch/issues/129579", + "https://github.com/Lightning-AI/lightning-thunder/issues/703", ) for sample in 
op.sample_inputs(device, dtype, requires_grad=True): diff --git a/thunder/tests/test_jit_general.py b/thunder/tests/test_jit_general.py index fa079e04f4..37036a90b9 100644 --- a/thunder/tests/test_jit_general.py +++ b/thunder/tests/test_jit_general.py @@ -687,7 +687,7 @@ def test_litgpt_variants(name, device): @pytest.mark.skipif( version_between(torch.__version__, min_ver="2.5.0a0", max_ver="2.5.0a99"), - reason="https://github.com/pytorch/pytorch/issues/129579", + reason="https://github.com/Lightning-AI/lightning-thunder/issues/669", ) @skipif_not_pytorch_2_1 @pytest.mark.parametrize( From ab514fcb7836683c7b57f9bd7fe338a77d516bd2 Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Thu, 4 Jul 2024 16:41:50 +0900 Subject: [PATCH 05/10] Add lookaside for `torch.autograd.function.Function.apply` (#707) --- thunder/core/jit_ext.py | 14 ++++++++++++++ thunder/tests/test_core.py | 26 ++++++++++++++++++++++++++ 2 files changed, 40 insertions(+) diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py index d69e8796a9..ac9e127a3f 100644 --- a/thunder/core/jit_ext.py +++ b/thunder/core/jit_ext.py @@ -891,6 +891,20 @@ def _general_jit_named_buffers_lookaside(obj: Any, *args, **kwargs): ) +@general_jit_lookaside(torch.autograd.function.Function.apply.__func__) +def _general_jit_torch_autograd_function_apply_lookaside(obj: Any, *args, **kwargs): + + custom_autograd_function_cls = unwrap(obj) + custom_forward = custom_autograd_function_cls.forward + args_, kwargs_ = tree_map(unwrap, (args, kwargs)) + ctx = torch.autograd.function.FunctionCtx() + + pr = ProvenanceRecord(PseudoInst.LOOKASIDE, inputs=[wrap_const(custom_forward).provenance]) + wrapped_ctx = wrap(ctx, provenance=pr) + args_, kwargs_ = tree_map(lambda a: wrap(a, provenance=pr), (args_, kwargs_)) + return _interpret_call(custom_forward, wrapped_ctx, *args_, **kwargs_) + + # Adds proxy methods # NOTE These methods map to themselves, which prevents the interpreter from looking into them # This is OK because these methods are written in a tracing-safe manner, and trying to diff --git a/thunder/tests/test_core.py b/thunder/tests/test_core.py index d63ebe7a14..7057f5212a 100644 --- a/thunder/tests/test_core.py +++ b/thunder/tests/test_core.py @@ -2854,6 +2854,32 @@ def forward(self, x) -> torch.Tensor: with pytest.raises(GradcheckError): gradcheck(model, (x,)) + class MyLinear(torch.autograd.Function): + @staticmethod + def forward(ctx, x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: + ctx.save_for_backward(x) + ctx.pretty_attr = 100 + return torch.matmul(x, weight.t()) + + @staticmethod + def backward(ctx, grad_output): + (x,) = ctx.saved_tensors + return torch.matmul(grad_output, weight), torch.matmul(grad_output.t(), x) + + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.l1 = torch.nn.Linear(2, 2, bias=False) + + def forward(self, x): + return MyLinear.apply(x, self.l1.weight) + + x = torch.randn((2, 2), dtype=torch.float64, requires_grad=True) + model = Model().to(dtype=torch.float64) + jitted = thunder.jit(model, skip_inplace_functionalization=True) + + gradcheck(jitted, (x,)) + def test_proxy_repr(): # Verify that we can call `__repr__` on different proxy subclasses. 
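One behavioural note on the Function.apply lookaside added above: it unwraps the autograd.Function subclass, builds a fresh FunctionCtx, and interprets only forward. The user-defined backward is not consulted at this stage, so gradients come from Thunder's own rules for the ops used inside forward, which is also what the new gradcheck test exercises. A hypothetical usage sketch in the spirit of the MyLinear test:

import torch
import thunder

class Scale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor, alpha: float) -> torch.Tensor:
        ctx.alpha = alpha              # attribute writes on ctx are interpreted too
        return alpha * x

    @staticmethod
    def backward(ctx, grad_output):
        return ctx.alpha * grad_output, None   # not traced by this lookaside

def f(x):
    return Scale.apply(x, 3.0)

jf = thunder.jit(f)
x = torch.randn(2, 2, requires_grad=True)
jf(x).sum().backward()                 # grad of the traced `alpha * x`, i.e. 3 * ones_like(x)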
From 6f3d0135bcb6618570b251ed354a55d417f08e6a Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Thu, 4 Jul 2024 18:26:23 +0900 Subject: [PATCH 06/10] Remove `skip_inplace_functionalization=True` from test_core (#712) --- thunder/tests/test_core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/thunder/tests/test_core.py b/thunder/tests/test_core.py index 7057f5212a..757114fc56 100644 --- a/thunder/tests/test_core.py +++ b/thunder/tests/test_core.py @@ -2848,7 +2848,7 @@ def forward(self, x) -> torch.Tensor: x = torch.randn((2, 2), dtype=torch.float64, requires_grad=True) model = Model().to(dtype=torch.float64) - jitted = thunder.jit(model, skip_inplace_functionalization=True) + jitted = thunder.jit(model) gradcheck(jitted, (x,)) with pytest.raises(GradcheckError): @@ -2876,7 +2876,7 @@ def forward(self, x): x = torch.randn((2, 2), dtype=torch.float64, requires_grad=True) model = Model().to(dtype=torch.float64) - jitted = thunder.jit(model, skip_inplace_functionalization=True) + jitted = thunder.jit(model) gradcheck(jitted, (x,)) From 0dc9807c00bae6076c15a5fdc32a2653fc00e225 Mon Sep 17 00:00:00 2001 From: Kshiteej K Date: Thu, 4 Jul 2024 11:27:00 +0200 Subject: [PATCH 07/10] Update dtype repr in trace (#704) --- thunder/core/codeutils.py | 5 ++++- thunder/tests/test_core.py | 19 +++++++++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/thunder/core/codeutils.py b/thunder/core/codeutils.py index 2c74cc8053..fc7a5c7ecd 100644 --- a/thunder/core/codeutils.py +++ b/thunder/core/codeutils.py @@ -227,7 +227,10 @@ def prettyprint( unflattened_str = unflattened_str.replace(f"'{_quote_marker}", "") return unflattened_str if isinstance(x, dtypes.dtype): - return m(f"dtypes.{str(x)}") + # str(x) -> thunder.dtypes.foo + # For consistency with previous repr, + # remove `thunder.` from the representation. 
+ return m(f"{str(x).replace('thunder.', '')}") if isinstance(x, devices.Device): return m(f'devices.Device("{x.device_str()}")') if type(x) is type: diff --git a/thunder/tests/test_core.py b/thunder/tests/test_core.py index 757114fc56..68312ca0d3 100644 --- a/thunder/tests/test_core.py +++ b/thunder/tests/test_core.py @@ -2915,3 +2915,22 @@ def fn(x): (pystr,) = tr.bound_symbols[1].python(0) assert pystr == 'result = ltorch.mul(2, x) # result: "cpu f32[2, 2]"' + + +def test_dtype_in_trace(): + def fn(x): + return x.to(torch.float16) + + jfn = thunder.jit(fn) + + x = torch.randn( + 3, + ) + + jfn(x) + + tr = thunder.last_traces(jfn)[0] + assert tr.bound_symbols[1].sym == ltorch.to + (pystr,) = tr.bound_symbols[1].subsymbols[0].python(0) + + assert "convert_element_type(x, dtypes.float16)" in pystr From e9cfc1a5b6b211e03fb0ec906344f85bf4f9ae43 Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Thu, 4 Jul 2024 11:44:47 +0200 Subject: [PATCH 08/10] launch tests with cuda_launch_blocking (#708) --- .azure/gpu-tests.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.azure/gpu-tests.yml b/.azure/gpu-tests.yml index 3040e2fbf8..ebae7ada7a 100644 --- a/.azure/gpu-tests.yml +++ b/.azure/gpu-tests.yml @@ -79,6 +79,7 @@ jobs: - bash: | set -ex + export CUDA_LAUNCH_BLOCKING=1 coverage run --source thunder -m \ pytest thunder/tests/ \ -m "not standalone" \ From c816506d4dc61dfb6da8bc6fe4c34de8a6706399 Mon Sep 17 00:00:00 2001 From: Kshiteej K Date: Thu, 4 Jul 2024 13:16:33 +0200 Subject: [PATCH 09/10] autocast: support mixed dtypes (#705) --- thunder/__init__.py | 9 +------ thunder/core/transforms.py | 49 +++++++++++++++++++++++++++------- thunder/tests/test_autocast.py | 16 +++++++++++ thunder/torch/__init__.py | 46 +++++++++++++++++++++++++++---- 4 files changed, 97 insertions(+), 23 deletions(-) diff --git a/thunder/__init__.py b/thunder/__init__.py index ce16cd1f79..c08b8528f4 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -391,6 +391,7 @@ def get_computation_and_inputs(*args, **kwargs): autocast_cpu_dtype=str(autocast_cpu_dtype), ) autocast_thunder_dtype = autocast_cpu_dtype if pytorch.is_autocast_cpu_enabled() else autocast_gpu_dtype + cache_info.update(autocast_thunder_dtype=str(autocast_thunder_dtype)) cache_info["is_autocast_enabled"] = is_autocast_enabled @@ -571,14 +572,6 @@ def get_computation_and_inputs(*args, **kwargs): computation_trc = dce(computation_trc) computation_traces.append(computation_trc) - if is_autocast_enabled: - from thunder.core.transforms import autocast - - computation_trc = trace(compile_data=cd)( - autocast(computation_trc.python_callable(), dtype=autocast_thunder_dtype), *inps - ) - computation_traces.append(computation_trc) - backward_trc = None if not cd.disable_torch_autograd_support: tensor_cls = (pytorch.Tensor, TensorProxy) diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py index 84804080c1..a7bfeda19d 100644 --- a/thunder/core/transforms.py +++ b/thunder/core/transforms.py @@ -3716,6 +3716,9 @@ def backward_fn(saved_for_backward, cotangents): autocast_impls: dict[prims.PrimIDs, Callable] = {} +# NOTE: Rules which are registered ltorch symbols should match the type signature +# of those symbols as we use this rule for translating from `torch` -> `thunder.torch` +# if autocast is enabled while jitting. See also `NOTE: torch.autocast support`. 
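The rules registered below funnel their tensor arguments through maybe_downcast_to, which converts only floating-point inputs that are already fp16/bf16/fp32 and leaves everything else untouched. A standalone illustration of that policy in plain PyTorch terms (Thunder's version operates on proxies and thunder dtypes; this is only a sketch of the behaviour):

import torch

_CONVERTIBLE = (torch.float16, torch.bfloat16, torch.float32)

def maybe_downcast(dtype: torch.dtype, args):
    # downcast eligible floating tensors, pass everything else through unchanged
    return tuple(
        a.to(dtype) if isinstance(a, torch.Tensor) and a.dtype in _CONVERTIBLE else a
        for a in args
    )

x = torch.randn(4, 4)                          # fp32 -> converted
w = torch.randn(4, 4, dtype=torch.bfloat16)    # already bf16 -> .to() is a no-op
i = torch.arange(4)                            # int64 -> left alone
x_c, w_c, i_c = maybe_downcast(torch.bfloat16, (x, w, i))
assert x_c.dtype == w_c.dtype == torch.bfloat16 and i_c.dtype == torch.int64

This is also why the jitted mixed-dtype linear in the new test below matches eager autocast: both operands end up in the autocast dtype before prims.linear runs, and the linear rule additionally keeps a None bias out of the conversion.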
def register_autocast_rule(op): def decorator(func): autocast_impls[op] = func @@ -3724,6 +3727,21 @@ def decorator(func): return decorator +_allowed_downcast_types = {dtypes.float16, dtypes.bfloat16} +_allowed_downcast_types_str_to_dtype_map = {str(dtype): dtype for dtype in _allowed_downcast_types} + + +def _check_valid_autocast_dtype(dtype): + if dtype not in _allowed_downcast_types: + raise ValueError( + f"autocast: `dtype` is expected to be either `thunder.float16` or `thunder.bfloat16`, but {dtype}" + ) + + +def _get_downcast_dtype_from_str(str_dtype): + return _allowed_downcast_types_str_to_dtype_map[str_dtype] + + def maybe_downcast_to(dtype, args): allowed_downcast_types = (dtypes.float16, dtypes.bfloat16, dtypes.float32) @@ -3742,9 +3760,7 @@ def autocast_matmul_rule(a, b, dtype): return prims.matmul(*(maybe_downcast_to(dtype, (a, b)))) -@register_autocast_rule("torch.nn.functional.linear") -@register_autocast_rule(prims.PrimIDs.LINEAR) -def autocast_linear_rule(a, w, bias, dtype): +def _linear_autocast_impl(a, w, bias, dtype): if bias is None: # Don't pass `bias` to maybe_downcast_to. downcast_args = maybe_downcast_to(dtype, (a, w)) + (bias,) @@ -3754,17 +3770,27 @@ def autocast_linear_rule(a, w, bias, dtype): return prims.linear(*downcast_args) +@register_autocast_rule("torch.nn.functional.linear") +def autocast_ltorch_linear_rule(a, w, bias=None, *, dtype): + return _linear_autocast_impl(a, w, bias, dtype) + + +@register_autocast_rule(prims.PrimIDs.LINEAR) +def autocast_linear_rule(a, w, bias, dtype): + return _linear_autocast_impl(a, w, bias, dtype) + + @register_autocast_rule("torch.nn.functional.scaled_dot_product_attention") def autocast_scaled_dot_product_attention( query, key, value, - attn_mask, - dropout_p, - is_causal, + attn_mask=None, + dropout_p=0.0, + is_causal=False, *, dtype, - scale, + scale=None, ): from thunder.torch import scaled_dot_product_attention @@ -3772,6 +3798,10 @@ def autocast_scaled_dot_product_attention( return scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal, scale=scale) +def _maybe_get_autocast_rule_for_symbol(sym: Symbol): + return autocast_impls.get(sym.id) + + def autocast_symbol_mapper(bound_symbol: BoundSymbolInterface, dtype: dtypes.dtype): """Return the callable implementing the autocast rule for the symbol. @@ -3781,7 +3811,7 @@ def autocast_symbol_mapper(bound_symbol: BoundSymbolInterface, dtype: dtypes.dty Returns: Callable: The callable implementing the autocast rule for the symbol. 
""" - autocast_impl: Callable | None = autocast_impls.get(bound_symbol.sym.id) + autocast_impl: Callable | None = _maybe_get_autocast_rule_for_symbol(bound_symbol.sym) return bound_symbol.sym if autocast_impl is None else partial(autocast_impl, dtype=dtype) @@ -3799,8 +3829,7 @@ def autocast(func: Callable, dtype: dtypes.dtype): if not isinstance(dtype, dtypes.dtype): raise ValueError(f"`dtype` is expected to be `thunder.dtype.dtype` but {type(dtype)}") - if dtype not in {dtypes.float16, dtypes.bfloat16}: - raise ValueError(f"`dtype` is expected to be either `thunder.float16` or `thunder.bfloat16`, but {dtype}") + _check_valid_autocast_dtype(dtype) @wraps(func) def wrapper(*args, **kwargs): diff --git a/thunder/tests/test_autocast.py b/thunder/tests/test_autocast.py index f219b5d848..9900f3975a 100644 --- a/thunder/tests/test_autocast.py +++ b/thunder/tests/test_autocast.py @@ -153,3 +153,19 @@ def fn(x, y): actual = cfn(a, b) expected = a + b torch.testing.assert_close(actual, expected) + + +def test_autocast_mixed_dtype_inputs(): + def foo(x, w): + return torch.nn.functional.linear(x, w) + + # Mixed input types. + x, w = torch.randn(16, 16, dtype=torch.bfloat16), torch.randn(16, 16) + + jfoo = thunder.jit(foo) + + with torch.autocast("cpu", torch.bfloat16): + eager_out = foo(x, w) + jit_out = jfoo(x, w) + + torch.testing.assert_close(eager_out, jit_out) diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py index 7cc1f17ac3..6788c538da 100644 --- a/thunder/torch/__init__.py +++ b/thunder/torch/__init__.py @@ -41,6 +41,12 @@ from thunder.core.transforms import register_grad from thunder.core.prims import get_grad, put_grad from thunder.core.baseutils import run_once +from thunder.core.transforms import ( + _maybe_get_autocast_rule_for_symbol, + _get_downcast_dtype_from_str, + _check_valid_autocast_dtype, +) +import thunder __all__ = [ @@ -137,22 +143,52 @@ def __call__(self, fn: Callable) -> Symbol: else: sym = Symbol(name=fn.__name__, meta=_fn, id=id, is_prim=self.is_prim, tags=self.tags) + # NOTE: torch.autocast support + # In PyTorch eager, enabling autocast allows mixed inputs to operator like `linear` + # which expect dtypes to be same. This works in PyTorch eager, as dispatcher applies the + # conversion first and then passes the converted input to the operator. + # To mimick similar behavior, here we replace the `sym` for all operators which have + # autocast rule, to apply the conversion rule if autocast was enabled. + autocast_impl: Callable | None = _maybe_get_autocast_rule_for_symbol(sym) + + # `mapping_fn` is used to map `torch` -> `thunder.torch` + # If autocast is enabled - it will also take care of casting the inputs + # as per autocast rule else it will be just the symbol. + mapping_fn = sym + if autocast_impl is not None: + + @wraps(sym) + def maybe_autocast(*args, **kwargs): + # Cache info may not be available in case of legacy `thunder.compile` + # which is still used for cases. 
+ try: + cache_info = thunder._get_cache_info() + except LookupError: + cache_info = {} + + if cache_info.get("is_autocast_enabled", False): + thunder_autocast_dtype = _get_downcast_dtype_from_str(cache_info["autocast_thunder_dtype"]) + return partial(autocast_impl, dtype=thunder_autocast_dtype)(*args, **kwargs) + return sym(*args, **kwargs) + + mapping_fn = maybe_autocast + if self.is_method: method_name: str = self.method_name if self.method_name is not None else fn.__name__ - register_method(method_name, sym) + register_method(method_name, mapping_fn) torch_method: None | Callable = getattr(torch.Tensor, method_name, None) if torch_method is not None: - _torch_to_thunder_function_map[torch_method] = sym + _torch_to_thunder_function_map[torch_method] = mapping_fn elif self.is_property: method_name: str = self.method_name if self.method_name is not None else fn.__name__ - register_property(method_name, sym) + register_property(method_name, mapping_fn) torch_property = getattr(torch.Tensor, method_name, None) if torch_property is not None: - _torch_to_thunder_function_map[torch_property] = sym + _torch_to_thunder_function_map[torch_property] = mapping_fn if self.torchfns is not None: for torchfn in self.torchfns: - _torch_to_thunder_function_map[torchfn] = sym + _torch_to_thunder_function_map[torchfn] = mapping_fn if self.tags and prims.OpTags.IN_PLACE in self.tags: _inplace_to_out_of_place[sym] = globals()[name[:-1]], -1 From 40da5bd5fabc30e99883d74b70c6a7d7fd61a828 Mon Sep 17 00:00:00 2001 From: Tom Fogal <60981+tfogal@users.noreply.github.com> Date: Thu, 4 Jul 2024 06:24:27 -0700 Subject: [PATCH 10/10] Add missing interpolate() parameters (#679) --- thunder/executors/torchex.py | 3 +++ thunder/torch/__init__.py | 9 +++++++++ 2 files changed, 12 insertions(+) diff --git a/thunder/executors/torchex.py b/thunder/executors/torchex.py index c49ff95358..f37d2094a2 100644 --- a/thunder/executors/torchex.py +++ b/thunder/executors/torchex.py @@ -1453,6 +1453,9 @@ def _interpolate_checker( size: int | Sequence[int] | None = None, scale_factor: float | Sequence[float] | None = None, mode: str = "nearest", + align_corners=None, + recompute_scale_factor=None, + antialias=False, ) -> TensorLike: return 3 <= a.ndim and a.ndim <= 5 diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py index 6788c538da..380930f6b7 100644 --- a/thunder/torch/__init__.py +++ b/thunder/torch/__init__.py @@ -4579,6 +4579,9 @@ def interpolate( size: None | int | Sequence[int] = None, scale_factor: float | Sequence[float] | None = None, mode: str = "nearest", + align_corners: bool | None = None, + recompute_scale_factor: bool | None = None, + antialias: bool = False, ) -> TensorLike: utils.check( mode == "nearest", @@ -4588,6 +4591,12 @@ def interpolate( utils.check(a.ndim >= 3, lambda: f"Expected {a.ndim=} >= 3") utils.check(a.numel() > 0, lambda: f"Expected {a.numel=} to be greater than 0") + utils.check(align_corners == None, lambda: f"'align_corners' is not yet supported.") + utils.check( + recompute_scale_factor is None or recompute_scale_factor == False, + lambda: f"'recompute_scale_factor=True' is not yet supported.", + ) + utils.check(antialias == False, lambda: f"'antialias=True' is not yet supported.") utils.check( (size is not None) ^ (scale_factor is not None),