From 9691d5dee89f162004905c3079de30da04f644b7 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 19 Jun 2024 23:09:22 -0700 Subject: [PATCH 01/14] adding Static Constraint field in NumberProxy (#586) --- thunder/core/jit_ext.py | 59 ++++++++++++++++++++++++++- thunder/core/prims.py | 13 +++--- thunder/core/proxies.py | 66 +++++++++++++++++++++++++------ thunder/core/utils.py | 18 ++++++++- thunder/tests/test_jit_general.py | 52 +++++++++++++++++++++++- 5 files changed, 187 insertions(+), 21 deletions(-) diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py index a82937deee..0ce1d51400 100644 --- a/thunder/core/jit_ext.py +++ b/thunder/core/jit_ext.py @@ -119,6 +119,7 @@ EXT_FLAG_IS_MODULE_MEMBER_DICT = 4 EXT_FLAG_IS_MODULE = 8 EXT_FLAG_IS_CALLABLE = 16 +EXT_FLAG_IS_CONSTRAINABLE_INPUT = 32 MODULE_MEMBER_DICT_ATTRS = { "_parameters", "_modules", @@ -618,6 +619,8 @@ def proxify(self, value: WrappedValue) -> Any: value.provenance.ext_flag |= EXT_FLAG_IS_PROXY_DERIVED # we follow the caching mechanisms of the eager_unpack_interpreter p = proxy(uvalue, history=value.provenance) + if value.provenance.ext_flag & EXT_FLAG_IS_CONSTRAINABLE_INPUT and hasattr(p, "make_constrainable"): + p.make_constrainable() assert p.history is not None, f"{p.history}, {value.provenance} {type(p)}" co: CACHE_OPTIONS = get_cache_option() @@ -815,6 +818,11 @@ def _general_jit_hasattr_lookaside(obj: Any, name: str): # recording the constraint to conditional jumps and such. def _general_jit_bool_lookaside(wrapped_x: Any) -> bool | INTERPRETER_SIGNALS: assert isinstance(wrapped_x, WrappedValue) + # It doesn't feel right to insert constraints in bool lookaside, constraints here only applies when the bool value is used in control flow. + if isinstance(wrapped_x.value, NumberProxy): + if wrapped_x.value.is_dynamic(): + raise NotImplementedError(f"conversion to bool is not allowed on dynamic proxy={wrapped_x.value}") + wrapped_x.value.make_static_constrained() bool_lookaside = default_lookaside(bool) or bool return bool_lookaside(wrapped_x) @@ -1134,6 +1142,7 @@ def _general_jit_wrap_callback(value): pass elif should_register_for_prologue(value.provenance): value.provenance.ext_flag |= EXT_FLAG_IS_PROXY_DERIVED + value.provenance.ext_flag |= EXT_FLAG_IS_CONSTRAINABLE_INPUT # we follow the caching mechanisms of the eager_unpack_interpreter p = ctx.proxify(value) else: @@ -1179,6 +1188,52 @@ def _general_jit_store_deref_callback( general_jit_callbacks = default_callbacks | general_jit_callbacks +# This pass identifies NumberProxy that's marked as statically constrained and propagate the constraints to inputs to the trace. +# The logic is that, if all inputs that produces a NumberProxy is marked statically constrained, then the value of the NumberProxy is statically constrained. +# This pass currently only does backward propagation to insert constraints in prologue trace +# TODO: We should be able to apply constant-folding and simplify computation_trace. +# TODO: If we allow symbolic constraints, we would be able to get more cache re-use. i.e. rather than requiring a NumberProxy to be static, we can have a finer grained constraints as `check_number_gt`. +def propagate_constraints(ctx, inputs, intermediates, computation_trace): + import thunder.core.utils as utils + + # set of NumberProxy variables that has already been traversed and marked as statically constrained. 
+ static_np_set = set() + + # add static constraints for inputs + for inp in inputs: + u_inp = unvariableify(inp) + if not isinstance(u_inp, NumberProxy): + continue + if u_inp.is_static_constrained(): + ctx.add_constraint((clang.check_number_type_and_value, u_inp, u_inp.value)) + static_np_set.add(inp) + + producers = utils.producers(computation_trace.bound_symbols, _map_to_numbers=False) + # add static constraints propagated from intermediates. + for intermediate in intermediates: + u_intermediate = unvariableify(intermediate) + if not isinstance(u_intermediate, NumberProxy) or not u_intermediate.is_static_constrained(): + continue + + # DFS traversal along producers, starting from seed `intermediate` + front = [intermediate] + while len(front) != 0: + v = front.pop() + if v in static_np_set: + continue + static_np_set.add(v) + + uv = unvariableify(v) + if v in inputs: + ctx.add_constraint((clang.check_number_type_and_value, uv, uv.value)) + else: + producer = producers[uv] + for inp in producer.flat_proxy_args: + if not isinstance(inp, NumberProxy): + continue + front.append(variableify(inp)) + + def get_computation_inputs_and_intermediates(computation_trace): inputs_list = [] inputs_set = set() @@ -1609,12 +1664,14 @@ def thunder_general_jit( last_interpreter_log = jfn._last_interpreter_log pro_to_comp, computation_intermediates = get_computation_inputs_and_intermediates(computation_trace) - epilogue_inputs, _ = get_computation_inputs_and_intermediates(epilogue_trace) comp_to_epi = [] pro_to_epi = [] + # propagate static constrained intermediates to inputs + propagate_constraints(ctx, pro_to_comp, computation_intermediates, computation_trace) + for i in epilogue_inputs: if i in computation_intermediates: comp_to_epi.append(i) diff --git a/thunder/core/prims.py b/thunder/core/prims.py index eef31b0bff..21f42c25c6 100644 --- a/thunder/core/prims.py +++ b/thunder/core/prims.py @@ -63,6 +63,7 @@ def register_method(method_name: str, method: Callable, /) -> None: from thunder.core.symbol import Symbol, BoundSymbol, default_python_printer from thunder.core.proxies import ( + CONSTRAINT, CollectionProxy, TensorProxy, NumberProxy, @@ -1717,7 +1718,7 @@ def _convert_element_type_meta(a: Number | TensorProxy, /, dtype: type | dtypes. utils.check(utils.is_numbertype(dtype), lambda: f"Trying to convert a number to non-numbertype object {dtype}") if isinstance(a, NumberProxy): - return numberproxy(dtype, dtype(utils.get_numberlike_value(a))) + return numberproxy(dtype, dtype(utils.get_numberlike_value(a)), constraint=a.constraint) number_result = dtype(a) return number_result @@ -1830,7 +1831,7 @@ def meta(a: Number | TensorProxy, /) -> Number | TensorProxy: isinstance(a, NumberProxy), lambda: f"Trying to call an elementwise unary operation {name} on a number, but the operation is not eagerly defined", ) - return numberproxy(output_type, None) + return numberproxy(output_type, None, a.constraint) # need to cast val to python_type in order to properly propagate output dtype. 
value = number_fn(typ(val)) @@ -1840,8 +1841,8 @@ def meta(a: Number | TensorProxy, /) -> Number | TensorProxy: ) # Only returns a proxy if the input is a proxy - if isinstance(a, Proxy): - return numberproxy(type(value), value) + if isinstance(a, NumberProxy): + return numberproxy(type(value), value, a.constraint) return value # NOTE a is a TensorProxy @@ -2246,12 +2247,12 @@ def meta( isinstance(a, NumberProxy) or isinstance(b, NumberProxy), lambda: f"Trying to call an elementwise binary operation {name} on two numbers, but the operation is not eagerly defined", ) - return numberproxy(numbertype, None) + return numberproxy(numbertype, None, constraint=utils.resolve_constraints(a, b)) value = number_fn(aval, bval) # Only returns a NumberProxy if at least one input is a number proxy if isinstance(a, NumberProxy) or isinstance(b, NumberProxy): - return numberproxy(type(value), value) + return numberproxy(type(value), value, constraint=utils.resolve_constraints(a, b)) return value else: diff --git a/thunder/core/proxies.py b/thunder/core/proxies.py index 04d95c59b1..12885c794e 100644 --- a/thunder/core/proxies.py +++ b/thunder/core/proxies.py @@ -567,14 +567,35 @@ def setdefault(self, *args): raise NotImplementedError("Calling setdefault on an input dict is not yet supported") +# CONSTRAINT annotates NumberProxy as their get processed by interpreter. +# A NumberProxy can be at one of the status: +# - A DYNAMIC NumberProxy cannot be converted to a static number; +# - A CONSTRAINABLE NumberProxy is treated as DYNAMIC by default, but it could be converted to STATIC by interpreter; +# - A STATIC NumberProxy can be treated as a static number, but not necessarily so; +# The protocol here is that, if a NumberProxy instance is converted to static, we'll insert a guard logic in prologue trace to ensure the NumberProxy doesn't change at runtime. +class CONSTRAINT(Enum): + DYNAMIC = auto() + CONSTRAINABLE = auto() + STATIC = auto() + + # NOTE NumberProxies are NOT Numbers # TODO Maybe NumberProxies should be Numbers? 
class NumberProxy(Proxy, NumberProxyInterface): def __init__( - self, name: str | None = None, value: Number | None = None, *, python_type: type, history: None | tuple = None + self, + name: str | None = None, + value: Number | None = None, + *, + python_type: type, + history: None | tuple = None, + constraint: None | CONSTRAINT = None, ): self.value = value self.python_type = python_type + if constraint is None: + constraint = CONSTRAINT.DYNAMIC + self.constraint = constraint Proxy.__init__(self, name, history=history) @@ -589,6 +610,19 @@ def replace_name(self, name: str, /): def known_value(self) -> bool: return self.value is not None + def make_static_constrained(self): + baseutils.check(self.constraint != CONSTRAINT.DYNAMIC, lambda: f"dynamic NumberProxy cannot be made static") + self.constraint = CONSTRAINT.STATIC + + def make_constrainable(self): + self.constraint = CONSTRAINT.CONSTRAINABLE + + def is_static_constrained(self) -> bool: + return self.constraint == CONSTRAINT.STATIC + + def is_dynamic(self) -> bool: + return self.constraint == CONSTRAINT.DYNAMIC + # # Elementwise unary operators # @@ -944,8 +978,8 @@ def pytype(x: Proxy) -> type | None: # TODO RC1 Update Proxy number inits to be value, /, *, name, history class ComplexProxy(NumberProxy): - def __init__(self, name=None, value=None, history: None | tuple = None): - NumberProxy.__init__(self, name=name, value=value, python_type=complex, history=history) + def __init__(self, name=None, value=None, history: None | tuple = None, constraint: None | CONSTRAINT = None): + NumberProxy.__init__(self, name=name, value=value, python_type=complex, history=history, constraint=constraint) def replace_name(self, name): """Return a copy of this proxy with the given name.""" @@ -959,10 +993,18 @@ def type_string(self): # TODO Review dtype conversions # TODO Review -9999 as the marker value for unknown values class IntegerProxy(NumberProxy): - def __init__(self, name: str | None = None, value=None, history: None | tuple = None): + def __init__( + self, + name: str | None = None, + value=None, + history: None | tuple = None, + constraint: None | CONSTRAINT = None, + ): # NOTE bools are also integers in Python python_type = bool if isinstance(value, bool) else int - NumberProxy.__init__(self, name=name, value=value, python_type=python_type, history=history) + NumberProxy.__init__( + self, name=name, value=value, python_type=python_type, history=history, constraint=constraint + ) def replace_name(self, name): """Return a copy of this proxy with the given name.""" @@ -975,8 +1017,8 @@ def type_string(self): def __repr__(self): if self.python_type is bool: - return f"[IntegerProxy (bool type) name={self.name}, value={self.value}]" - return f"[IntegerProxy name={self.name}, value={self.value}]" + return f"[IntegerProxy (bool type) name={self.name}, value={self.value}, static={self.constraint}]" + return f"[IntegerProxy name={self.name}, value={self.value}, static={self.constraint}]" def __index__(self): return self.value @@ -984,8 +1026,8 @@ def __index__(self): # TODO Review dtype conversions class FloatProxy(NumberProxy): - def __init__(self, name=None, value=None, history: None | tuple = None): - NumberProxy.__init__(self, name=name, value=value, python_type=float, history=history) + def __init__(self, name=None, value=None, history: None | tuple = None, constraint: None | CONSTRAINT = None): + NumberProxy.__init__(self, name=name, value=value, python_type=float, history=history, constraint=constraint) def replace_name(self, name): """Return 
a copy of this proxy with the given name.""" @@ -996,7 +1038,7 @@ def type_string(self): return f"float {value_str}" def __repr__(self): - return f"[FloatProxy name={self.name}, value={self.value}]" + return f"[FloatProxy name={self.name}, value={self.value}, static={self.constraint}]" class DistParallelType(Enum): @@ -1565,9 +1607,9 @@ def futuretensorproxy( ) -def numberproxy(cls: type, value: Number | None) -> NumberProxy: +def numberproxy(cls: type, value: Number | None, constraint: None | CONSTRAINT = None) -> NumberProxy: pcls = _cls_to_number_proxy_map[cls] - return pcls(value=value) + return pcls(value=value, constraint=constraint) # TODO RC1 Remove this function diff --git a/thunder/core/utils.py b/thunder/core/utils.py index 2996d72286..a094fc6548 100644 --- a/thunder/core/utils.py +++ b/thunder/core/utils.py @@ -13,7 +13,7 @@ import thunder.core.dtypes as dtypes from thunder.core.pytree import tree_flatten, tree_unflatten, tree_map -from thunder.core.proxies import Proxy, NumberProxy, TensorProxy, variableify +from thunder.core.proxies import Proxy, NumberProxy, TensorProxy, variableify, CONSTRAINT from thunder.core.baseutils import * from thunder.core.codeutils import * from thunder.core.trace import TraceCtx @@ -131,6 +131,22 @@ def debug_asserts_level() -> int: corresponding_complex_dtype = dtypes.corresponding_complex_dtype +# This function resolves the CONSTRAINT tag from args, by looking at each Proxy instance in args: +# TODO: we currently only considers NumberProxy could be statically constrained. This is likely going to be extended to other proxies in the future. +def resolve_constraints(*args): + all_static = True + for arg in args: + if not isinstance(arg, Proxy): + continue + if not isinstance(arg, NumberProxy) or arg.is_dynamic(): + return CONSTRAINT.DYNAMIC + if not arg.is_static_constrained(): + all_static = False + if all_static: + return CONSTRAINT.STATIC + return CONSTRAINT.CONSTRAINABLE + + def higher_dtype(a, b): for fn in ( is_complex_dtype, diff --git a/thunder/tests/test_jit_general.py b/thunder/tests/test_jit_general.py index 507dc5ba15..da7fe97c47 100644 --- a/thunder/tests/test_jit_general.py +++ b/thunder/tests/test_jit_general.py @@ -809,7 +809,7 @@ def test_tom_overrides_proxy(device): "device", ("cpu", "cuda"), ) -def test_cache_symbolic_values(device): +def test_cache_symbolic_values_basic(device): if device == "cuda" and not torch.cuda.is_available(): pytest.skip("CUDA not available") @@ -905,3 +905,53 @@ def foo(x, device): with ctx: actual_device = jfoo(x, expected_device).device assert actual_device == expected_device + + +def test_cache_symbolic_values_constraints(): + def foo(scalar): + if scalar > 0: + return scalar + return 0 + + jfoo = thunder.jit(foo, cache="symbolic values") + + expected = foo(1.5) + actual = jfoo(1.5) + + assert_close(expected, actual) + assert thunder.cache_misses(jfoo) == 1 + assert thunder.cache_hits(jfoo) == 0 + + expected = foo(2.0) + actual = jfoo(2.0) + + assert_close(expected, actual) + # even though we should be able to re-use the cache, we cannot do it now. Because constraints are propagated to inputs being static number. 
+ assert thunder.cache_misses(jfoo) == 2 + assert thunder.cache_hits(jfoo) == 0 + + expected = foo(1.5) + actual = jfoo(1.5) + + assert_close(expected, actual) + assert thunder.cache_misses(jfoo) == 2 + assert thunder.cache_hits(jfoo) == 1 + + expected = foo(-0.3) + actual = jfoo(-0.3) + + assert_close(expected, actual) + assert thunder.cache_misses(jfoo) == 3 + assert thunder.cache_hits(jfoo) == 1 + + def bar(t): + if t[0].item() > 5: + return t + 1.0 + return t + + with pytest.raises( + thunder.core.interpreter.InterpreterError, match="conversion to bool is not allowed on dynamic proxy" + ): + jbar = thunder.jit(bar, cache="symbolic values") + t = torch.randn(4, device="cpu") + jbar(t) From 5d18fce56a2e4c18ef04306c7c7b043cac882dac Mon Sep 17 00:00:00 2001 From: Vedaanta Agarwalla <142048820+vedaanta@users.noreply.github.com> Date: Wed, 19 Jun 2024 23:12:16 -0700 Subject: [PATCH 02/14] Bumps cudnn FE to v1.5 (#593) --- .azure/docker-build.yml | 8 ++++---- .azure/gpu-tests.yml | 8 ++++---- .azure/notebook-runs.yml | 4 ++-- dockers/ubuntu-cuda/Dockerfile | 2 +- thunder/executors/cudnnex.py | 7 +++++-- 5 files changed, 16 insertions(+), 13 deletions(-) diff --git a/.azure/docker-build.yml b/.azure/docker-build.yml index a33f4a2f58..df8a8874df 100644 --- a/.azure/docker-build.yml +++ b/.azure/docker-build.yml @@ -40,10 +40,10 @@ jobs: #maxParallel: "3" matrix: # CUDA 12.1 - "cuda 12.1 | torch 2.3 | cudnn FE v1.4": - { CUDA_VERSION: "12.1.1", TORCH_VERSION: "2.3.0", TRITON_VERSION: "2.3.0", CUDNN_FRONTEND_VERSION: "1.4.0" } - "cuda 12.1 | torch 2.4 /nightly | cudnn FE v1.4": - { CUDA_VERSION: "12.1.1", TORCH_VERSION: "main", TORCH_INSTALL: "source", CUDNN_FRONTEND_VERSION: "1.4.0" } + "cuda 12.1 | torch 2.3 | cudnn FE v1.5.1": + { CUDA_VERSION: "12.1.1", TORCH_VERSION: "2.3.0", TRITON_VERSION: "2.3.0", CUDNN_FRONTEND_VERSION: "1.5.1" } + "cuda 12.1 | torch 2.4 /nightly | cudnn FE v1.5.1": + { CUDA_VERSION: "12.1.1", TORCH_VERSION: "main", TORCH_INSTALL: "source", CUDNN_FRONTEND_VERSION: "1.5.1" } #'cuda 12.1': # this version - '8.9.5.29-1+cuda12.1' for 'libcudnn8' was not found # how much time to give 'run always even if cancelled tasks' before stopping them cancelTimeoutInMinutes: "2" diff --git a/.azure/gpu-tests.yml b/.azure/gpu-tests.yml index 6ba3398ebd..0b8bfdb348 100644 --- a/.azure/gpu-tests.yml +++ b/.azure/gpu-tests.yml @@ -17,17 +17,17 @@ jobs: matrix: # CUDA 12.1 "ubuntu22.04 | cuda 12.1 | python 3.10 | torch 2.3 | regular": - docker-image: "ubuntu22.04-cuda12.1.1-cudnn-fe1.4.0-py3.10-pt_2.3.0-dev" + docker-image: "ubuntu22.04-cuda12.1.1-cudnn-fe1.5.1-py3.10-pt_2.3.0-dev" CUDA_VERSION_MM: "121" "ubuntu22.04 | cuda 12.1 | python 3.10 | torch 2.3 | distributed": - docker-image: "ubuntu22.04-cuda12.1.1-cudnn-fe1.4.0-py3.10-pt_2.3.0-dev" + docker-image: "ubuntu22.04-cuda12.1.1-cudnn-fe1.5.1-py3.10-pt_2.3.0-dev" CUDA_VERSION_MM: "121" testing: "distributed" "ubuntu22.04 | cuda 12.1 | python 3.10 | torch-nightly | regular": - docker-image: "ubuntu22.04-cuda12.1.1-cudnn-fe1.4.0-py3.10-pt_main-dev" + docker-image: "ubuntu22.04-cuda12.1.1-cudnn-fe1.5.1-py3.10-pt_main-dev" CUDA_VERSION_MM: "121" "ubuntu22.04 | cuda 12.1 | python 3.10 | torch-nightly | distributed": - docker-image: "ubuntu22.04-cuda12.1.1-cudnn-fe1.4.0-py3.10-pt_main-dev" + docker-image: "ubuntu22.04-cuda12.1.1-cudnn-fe1.5.1-py3.10-pt_main-dev" CUDA_VERSION_MM: "121" testing: "distributed" # how much time to give 'run always even if cancelled tasks' before stopping them diff --git a/.azure/notebook-runs.yml 
b/.azure/notebook-runs.yml index 5181fba015..bfc326fb4c 100644 --- a/.azure/notebook-runs.yml +++ b/.azure/notebook-runs.yml @@ -16,10 +16,10 @@ jobs: strategy: matrix: "ubuntu22.04 | cuda 12.1 | torch 2.3": - docker-image: "ubuntu22.04-cuda12.1.1-cudnn-fe1.4.0-py3.10-pt_2.3.0-dev" + docker-image: "ubuntu22.04-cuda12.1.1-cudnn-fe1.5.1-py3.10-pt_2.3.0-dev" CUDA_VERSION_MM: "121" "ubuntu22.04 | cuda 12.1 | torch-nightly": - docker-image: "ubuntu22.04-cuda12.1.1-cudnn-fe1.4.0-py3.10-pt_main-dev" + docker-image: "ubuntu22.04-cuda12.1.1-cudnn-fe1.5.1-py3.10-pt_main-dev" CUDA_VERSION_MM: "121" # how long to run the job before automatically cancelling timeoutInMinutes: "45" diff --git a/dockers/ubuntu-cuda/Dockerfile b/dockers/ubuntu-cuda/Dockerfile index 6da7b4418f..d6cd7c8e6f 100644 --- a/dockers/ubuntu-cuda/Dockerfile +++ b/dockers/ubuntu-cuda/Dockerfile @@ -20,7 +20,7 @@ ARG IMAGE_TYPE="devel" FROM nvidia/cuda:${CUDA_VERSION}-${IMAGE_TYPE}-ubuntu${UBUNTU_VERSION} ARG CUDNN_VERSION="9.1.0.70" -ARG CUDNN_FRONTEND_VERSION="1.4.0" +ARG CUDNN_FRONTEND_VERSION="1.5.1" ARG PYTHON_VERSION="3.10" ARG TORCH_VERSION="2.2.1" ARG TRITON_VERSION="2.2.0" diff --git a/thunder/executors/cudnnex.py b/thunder/executors/cudnnex.py index 38aa997169..e2a67e0b94 100644 --- a/thunder/executors/cudnnex.py +++ b/thunder/executors/cudnnex.py @@ -28,8 +28,11 @@ def cudnn_version() -> LooseVersion | None: def required_cudnn_version() -> LooseVersion: - # Using 1.3.0 majorly because it works better with other libraries (e.g. torch) that also build on top of cudnn backend - return LooseVersion("1.3.0") + # History of versions: + # Using 1.3.0+ because it works better with other libraries (e.g. torch) that also build on top of cudnn + # Using 1.5.0+ because it handles exception with unsupported graphs better + # Using 1.5.1 because of a compatibility fix + return LooseVersion("1.5.1") def cudnn_available() -> bool: From 1f4590a896861f41605f6cefd033b57365893447 Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Thu, 20 Jun 2024 12:39:43 +0200 Subject: [PATCH 03/14] Fix issue template (#627) --- .github/ISSUE_TEMPLATE/program_coverage.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/ISSUE_TEMPLATE/program_coverage.md b/.github/ISSUE_TEMPLATE/program_coverage.md index dd733dfc99..a16df4058c 100644 --- a/.github/ISSUE_TEMPLATE/program_coverage.md +++ b/.github/ISSUE_TEMPLATE/program_coverage.md @@ -1,6 +1,6 @@ --- -name: Feature request -about: Suggest an idea for this project +name: Program Coverage +about: Expand the programs / models Thunder can process title: '' labels: program-coverage assignees: '' From e28ea5edfc69554bb2ee8725e5e50f5a2b0f6b41 Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Thu, 20 Jun 2024 22:26:54 +0900 Subject: [PATCH 04/14] Functionalize in-place ops (#584) --- thunder/__init__.py | 11 +- thunder/core/langctxs.py | 2 +- thunder/core/prims.py | 1 + thunder/core/proxies.py | 35 ++ thunder/core/transform_common.py | 115 ++++- thunder/tests/opinfos.py | 30 -- thunder/tests/test_core.py | 3 +- thunder/tests/test_inplace_copy.py | 5 +- .../tests/test_inplace_functionalization.py | 124 +++++ thunder/torch/__init__.py | 440 +++++++++++++++++- 10 files changed, 714 insertions(+), 52 deletions(-) create mode 100644 thunder/tests/test_inplace_functionalization.py diff --git a/thunder/__init__.py b/thunder/__init__.py index d48ddd9f51..e647a42aa0 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -33,7 +33,13 @@ import thunder.core.prims as prims import 
thunder.core.dtypes as dtypes import thunder.core.devices as devices -from thunder.core.transform_common import dce, EarlyTransform, AdditionalTransform, PostOptimizationTransform +from thunder.core.transform_common import ( + dce, + EarlyTransform, + AdditionalTransform, + PostOptimizationTransform, + functionalize_inplace_ops, +) from thunder.common import ( CompileData, CompileStats, @@ -503,6 +509,9 @@ def get_computation_and_inputs(*args, **kwargs): prologue_traces = [prologue_trc] computation_traces = [computation_trc] + if not compile_options.get("skip_inplace_functionalization", False): + computation_traces.extend(functionalize_inplace_ops(computation_trace=computation_trc)) + computation_trc = computation_traces[-1] if epilogue_trc is not None: epilogue_traces = [epilogue_trc] diff --git a/thunder/core/langctxs.py b/thunder/core/langctxs.py index 5360ea6bc1..1853b8ecfc 100644 --- a/thunder/core/langctxs.py +++ b/thunder/core/langctxs.py @@ -72,7 +72,7 @@ def resolve_method(id: Any, *args, **kwargs) -> None | Callable: # ctx.get_method throws an AttributeError when the context does not have the requested attribute, except # for the prims language context, which always throws a ValueError method: Callable = ctx.get_method(id, *args, **kwargs) - except (AttributeError, ValueError) as e: + except (AttributeError, ValueError): return None return method diff --git a/thunder/core/prims.py b/thunder/core/prims.py index 21f42c25c6..c0c11ab6f3 100644 --- a/thunder/core/prims.py +++ b/thunder/core/prims.py @@ -275,6 +275,7 @@ class OpTags(Enum): DEVICE_SYNC_OP = auto() # Labels operations that should not be removed by the dead code elimination (DCE) pass DONT_DCE = auto() + IN_PLACE = auto() # TODO RC1 Document this function and describe the parts of a primitive diff --git a/thunder/core/proxies.py b/thunder/core/proxies.py index 12885c794e..786163cbb5 100644 --- a/thunder/core/proxies.py +++ b/thunder/core/proxies.py @@ -1391,6 +1391,13 @@ def __add__(self, other): method = resolve_method("add", self, other) return method(self, other) + def __iadd__(self, other): + return self.add_(other) + + def add_(self, other): + method = resolve_method("add_", self, other) + return method(self, other) + def __radd__(self, other): method = resolve_method("add", other, self) return method(other, self) @@ -1427,6 +1434,13 @@ def __mul__(self, other): method = resolve_method("mul", self, other) return method(self, other) + def __imul__(self, other): + return self.mul_(other) + + def mul_(self, other): + method = resolve_method("mul_", self, other) + return method(self, other) + def __rmul__(self, other): method = resolve_method("mul", other, self) return method(other, self) @@ -1435,6 +1449,13 @@ def __pow__(self, other): method = resolve_method("pow", self, other) return method(self, other) + def __ipow__(self, other): + return self.pow_(other) + + def pow_(self, other): + method = resolve_method("pow_", self, other) + return method(self, other) + def __rpow__(self, other): method = resolve_method("pow", other, self) return method(other, self) @@ -1443,6 +1464,13 @@ def __sub__(self, other): method = resolve_method("sub", self, other) return method(self, other) + def __isub__(self, other): + return self.sub_(other) + + def sub_(self, other): + method = resolve_method("sub_", self, other) + return method(self, other) + def __rsub__(self, other): method = resolve_method("sub", other, self) return method(other, self) @@ -1455,6 +1483,13 @@ def __rtruediv__(self, other): method = resolve_method("true_divide", 
other, self) return method(other, self) + def __itruediv__(self, other): + return self.div_(other) + + def div_(self, other, *, rounding_mode: str | None = None): + method = resolve_method("div_", self, other, rounding_mode=rounding_mode) + return method(self, other) + # # Logical operations # diff --git a/thunder/core/transform_common.py b/thunder/core/transform_common.py index 7f8486f9e1..4b71d14b45 100644 --- a/thunder/core/transform_common.py +++ b/thunder/core/transform_common.py @@ -1,5 +1,6 @@ +from __future__ import annotations import time -from typing import Any +from typing import TYPE_CHECKING from abc import ABC, abstractmethod from collections.abc import Sequence from itertools import filterfalse @@ -7,12 +8,16 @@ import thunder.core.prims as prims from thunder.core.baseutils import BoundSymbolInterface -from thunder.core.proxies import Proxy, variableify, Variable -from thunder.core.pytree import tree_flatten, tree_map +from thunder.core.proxies import Proxy, variableify, Variable, TensorProxy +from thunder.core.pytree import tree_flatten, tree_map, tree_unflatten from thunder.core.symbol import BoundSymbol, BoundSymbolRHS, has_tags from thunder.core.trace import from_trace, TraceProvenance, TraceCtx as Trace from thunder.core.utils import ProxyDict, producers, check +if TYPE_CHECKING: + from thunder.core.proxies import ProxyInterface + from thunder.core.symbol import Symbol, VariableInterface + # # Common optimization and transform passes @@ -363,3 +368,107 @@ class PostOptimizationTransform(Transform, ABC): @abstractmethod def transform_trace(self, computation_trace: Trace, **kwargs): pass + + +def functionalize_inplace_ops(computation_trace: Trace) -> list[Trace]: + """Functionalize in-place ops in ``computation_trace``. + + In thunder, an in-place is an out-of-place or functional op followed by :func:`~thunder.core.prims.copy_`. + This function replaces such in-place ops with out-of-place ops. + Note that functionalization is not applied, if any of an in-place op's arguments are + ``computation_trace.args`` or ``computation_trace.kwargs``. + + For example, :func:`thunder.torch.add_` is represented as a :class:`thunder.core.symbol.BoundSymbol` + whose `subsymbols` are :func:`thunder.torch.add` and :func:`thunder.core.prims.copy_`. This function + replaces it with a :class:`~thunder.core.symbol.BoundSymbol` of :func:`~thunder.torch.add`. + """ + import thunder.torch + + def is_functionalizable(bsym: BoundSymbol) -> bool: + """Has `OpTags.IN_PLACE` and its args are NOT ``computation_trace.args`` nor ``computation_trace.kwargs``.""" + return ( + bsym.sym in thunder.torch._inplace_to_out_of_place + and bsym.subsymbols + and bsym.subsymbols[-1].sym.id == prims.PrimIDs.COPY_ + ) + + if not any(is_functionalizable(bsym) for bsym in computation_trace.bound_symbols): + return [] + + # Step 1: return the tensors returned from `prims.copy_` as possible not the args for clarity. + bsym: BoundSymbol + swap_map: dict[VariableInterface, ProxyInterface] = {} + bsyms: list[BoundSymbol] = [] + for bsym in computation_trace.bound_symbols: + new_bsym = bsym.from_bsym_swap_proxies(swap_map) + + # in-place functionalizable ops has `prims.copy_` as the last subsymbol. 
+ if not is_functionalizable(new_bsym): + bsyms.append(new_bsym) + continue + + copy_bsym = bsym.subsymbols[-1] + copy_out = copy_bsym.flat_proxy_outs[0] + copy_dst = copy_bsym.flat_proxy_args[1] + swap_map[variableify(copy_dst)] = copy_out + # make sure an in-place bsym returns `prims.copy_` output + new_bsym = new_bsym.from_bsym_swap_proxies(swap_map, skip_inputs=True, skip_subsymbols=True) + bsyms.append(new_bsym) + + intermediate_trace = from_trace(computation_trace) + intermediate_trace.bound_symbols = bsyms[:] + intermediate_trace.set_provenance(TraceProvenance("Intermediate trace of `functionalize_inplace_ops`")) + del bsyms + + # Step 2: Remove `prims.copy_` if it's the last one of `bsym.subsymbols`, + # unless `copy_to` is `computation_trace.args` or `computation_trace.kwargs` + trace_args_set = ProxyDict() + for a in filter( + lambda a: isinstance(a, TensorProxy), tree_flatten((computation_trace.args, computation_trace.kwargs))[0] + ): + trace_args_set[a] = a + bsym_inplace_to_functional = {} + swap_map.clear() + new_bsyms: list[BoundSymbol] = [] + for bsym in intermediate_trace.bound_symbols: + new_bsym = bsym.from_bsym_swap_proxies(swap_map) + + if not is_functionalizable(new_bsym): + new_bsyms.append(new_bsym) + continue + copy_bsym = bsym.subsymbols[-1] + copy_return = copy_bsym.flat_proxy_outs[0] + copy_from = copy_bsym.flat_proxy_args[0] + copy_to = copy_bsym.flat_proxy_args[1] + if copy_to in trace_args_set: + new_bsyms.append(new_bsym) + else: + swap_map[variableify(copy_return)] = copy_from + new_bsym.subsymbols = new_bsym.subsymbols[:-1] + new_bsym = new_bsym.from_bsym_swap_proxies(swap_map) + + functional_sym: Symbol + optional_inplace_arg_index: int + functional_sym, optional_inplace_arg_index = thunder.torch._inplace_to_out_of_place[new_bsym.sym] + + flat_args, flat_args_spec = tree_flatten((new_bsym.args, new_bsym.kwargs)) + if optional_inplace_arg_index > -1: + flat_args[optional_inplace_arg_index] = False + args, kwargs = tree_unflatten(flat_args, flat_args_spec) + new_functional_bsym = functional_sym.bind( + *args, + **kwargs, + output=new_bsym.output, + subsymbols=new_bsym.subsymbols, + _call_ctx=new_bsym._call_ctx, + ) + new_bsyms.append(new_functional_bsym) + bsym_inplace_to_functional[new_bsym] = new_functional_bsym + + functionalized_computation_trace = from_trace(computation_trace) + functionalized_computation_trace.bound_symbols = new_bsyms + functionalized_computation_trace.set_provenance(TraceProvenance("Functionalize in-place ops")) + # note(crcrpar): I kind of want to do the following two. 
+ # functionalized_computation_trace._provenance.swap_map = swap_map + # functionalized_computation_trace._provenance.bsym_inplace_to_functional = bsym_inplace_to_functional + return [intermediate_trace, functionalized_computation_trace] diff --git a/thunder/tests/opinfos.py b/thunder/tests/opinfos.py index 4d98a5a84d..66c15e6ce1 100644 --- a/thunder/tests/opinfos.py +++ b/thunder/tests/opinfos.py @@ -1278,16 +1278,10 @@ def _abs_torch(x: torch.Tensor | Number): elementwise_unary_ops.append(signbit_opinfo) -def silu_error_generator(op, device, dtype=torch.float32, **kwargs): - a = make_tensor((), dtype=dtype, device=device) - yield (SampleInput(a, inplace=True), NotImplementedError, "Thunder only supports silu with inplace=False") - - silu_opinfo = OpInfo( ltorch.silu, dtypes=(datatypes.floating,), sample_input_generator=partial(elementwise_unary_generator, supports_numbers=False), - error_input_generator=silu_error_generator, torch_reference=_elementwise_unary_torch(torch.nn.functional.silu), test_directives=( DecorateInfo( @@ -1623,20 +1617,9 @@ def silu_error_generator(op, device, dtype=torch.float32, **kwargs): elementwise_unary_ops.append(reciprocal_opinfo) -def relu_error_generator(op, device, dtype=torch.float32, **kwargs): - a = make_tensor((), dtype=dtype, device=device) - yield (SampleInput(a, inplace=True), NotImplementedError, "relu only supports inplace=False") - - -def relu6_error_generator(op, device, dtype=torch.float32, **kwargs): - a = make_tensor((), dtype=dtype, device=device) - yield (SampleInput(a, inplace=True), NotImplementedError, "relu6 only supports inplace=False") - - relu_opinfo = OpInfo( ltorch.relu, sample_input_generator=elementwise_unary_generator, - error_input_generator=relu_error_generator, torch_reference=_elementwise_unary_torch(torch.relu), test_directives=( # PyTorch does not support bool and complex types @@ -1665,7 +1648,6 @@ def relu6_error_generator(op, device, dtype=torch.float32, **kwargs): relu6_opinfo = OpInfo( ltorch.relu6, sample_input_generator=elementwise_unary_generator, - error_input_generator=relu6_error_generator, torch_reference=_elementwise_unary_torch(torch.nn.functional.relu6), test_directives=( # PyTorch does not support bool for both CPU and CUDA relu6 @@ -1684,15 +1666,9 @@ def relu6_error_generator(op, device, dtype=torch.float32, **kwargs): elementwise_unary_ops.append(relu6_opinfo) -def hardswish_error_generator(op, device, dtype=torch.float32, **kwargs): - a = make_tensor((), dtype=dtype, device=device) - yield (SampleInput(a, inplace=True), NotImplementedError, "hardswish only supports inplace=False") - - hardswish_opinfo = OpInfo( ltorch.hardswish, sample_input_generator=elementwise_unary_generator, - error_input_generator=hardswish_error_generator, torch_reference=_elementwise_unary_torch(torch.nn.functional.hardswish), dtypes=(datatypes.floating,), test_directives=( @@ -1713,16 +1689,10 @@ def hardswish_error_generator(op, device, dtype=torch.float32, **kwargs): elementwise_unary_ops.append(hardswish_opinfo) -def selu_error_generator(op, device, dtype=torch.float32, **kwargs): - a = make_tensor((), dtype=dtype, device=device) - yield (SampleInput(a, inplace=True), NotImplementedError, "selu only supports inplace=False") - - selu_opinfo = OpInfo( ltorch.selu, dtypes=(datatypes.floating,), sample_input_generator=elementwise_unary_generator, - error_input_generator=selu_error_generator, torch_reference=_elementwise_unary_torch(torch.selu), test_directives=( # Some versions of PyTorch do not support CPU float16 selu diff 
--git a/thunder/tests/test_core.py b/thunder/tests/test_core.py index fffc6098d1..8c8390f916 100644 --- a/thunder/tests/test_core.py +++ b/thunder/tests/test_core.py @@ -2121,7 +2121,8 @@ def test_xor(s, o): for t in tests: cfn = thunder.jit(t) - with pytest.raises(RuntimeError, match="not supported"): + # Some ops of `tests` already have in-place supported, leading to broadcast error + with pytest.raises(RuntimeError, match="not supported|Attempting"): cfn(t1, t2) # Note: Python maps inplace operations on (immutuables) to # out of place operations, NumberProxy does this, too. diff --git a/thunder/tests/test_inplace_copy.py b/thunder/tests/test_inplace_copy.py index c7121cea5b..fd7adc9230 100644 --- a/thunder/tests/test_inplace_copy.py +++ b/thunder/tests/test_inplace_copy.py @@ -89,8 +89,6 @@ class Net(nn.Module): def __init__(self): super().__init__() self.dense1_bn = nn.BatchNorm3d(2, track_running_stats=True) - # To address the failure, use a workaround since `add_` is utilized in `nn.BatchNorm3d` when `num_batches_tracked` is not None. - self.dense1_bn.num_batches_tracked = None def forward(self, x): x = self.dense1_bn(x) @@ -112,6 +110,9 @@ def forward(self, x): assert_close(thunder_out, torch_out) assert_close(net.state_dict()["dense1_bn.running_mean"], torch_net.state_dict()["dense1_bn.running_mean"]) assert_close(net.state_dict()["dense1_bn.running_var"], torch_net.state_dict()["dense1_bn.running_var"]) + assert_close( + net.state_dict()["dense1_bn.num_batches_tracked"], torch_net.state_dict()["dense1_bn.num_batches_tracked"] + ) assert_close(x.grad, x1.grad) diff --git a/thunder/tests/test_inplace_functionalization.py b/thunder/tests/test_inplace_functionalization.py new file mode 100644 index 0000000000..83f43bc6b4 --- /dev/null +++ b/thunder/tests/test_inplace_functionalization.py @@ -0,0 +1,124 @@ +from __future__ import annotations +from dataclasses import dataclass +from functools import partial +from collections.abc import Callable + +import torch.testing + +from thunder.core import dtypes +from thunder.core.prims import PrimIDs +from thunder.tests.framework import ops +from thunder.tests.opinfos import opinfos, OpInfo, make_number, SampleInput +from thunder.tests.make_tensor import make_tensor +from thunder.torch import _torch_to_thunder_function_map, _inplace_to_out_of_place + + +# `SampleInput`s of ops with `inplace` argument do not seem to come with `inplace` arg, so give it to them. 
+def sample_generator_wrapper(sample_generator): + + def f(*args, **kwargs): + for sample in sample_generator(*args, **kwargs): + yield SampleInput(*(list(sample.args) + [True]), **sample.kwargs) + + return f + + +def inplace_masked_fill_sample_generator(op, device, dtype, requires_grad, **kwargs): + make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + number = partial(make_number, dtype=dtype) + + # pred_shape, a_shape, value + cases = (((2, 2, 2), (2, 2, 2), number()),) + + for pred_shape, a_shape, value in cases: + pred, a = make(pred_shape, dtype=torch.bool, requires_grad=False), make(a_shape) + yield SampleInput(a, pred, value) + + +_torchsymbol_to_torch: dict[Sybmol, Callable] = {v: k for k, v in _torch_to_thunder_function_map.items()} +_functional_to_inplace: dict[Callable, Callable] = { + functional: inplace for inplace, (functional, index) in _inplace_to_out_of_place.items() if index == -1 +} +_functional_to_functional_with_inplace_arg: dict[Callable, tuple[Callable, int]] = { + functional: (inplace, index) for inplace, (functional, index) in _inplace_to_out_of_place.items() if index >= 0 +} +_inplace_opinfos: list[OpInfo] = [] +for op in opinfos: + if not (op.op in _functional_to_inplace or op.op in _functional_to_functional_with_inplace_arg): + continue + # ops that have an argument of `inplace` such as `F.relu` and `F.gelu` + if op.op in _functional_to_functional_with_inplace_arg: + inplace_op, _ = _functional_to_functional_with_inplace_arg[op.op] + assert op.name != "masked_fill" + inplace_opinfo = OpInfo( + inplace_op, + sample_input_generator=sample_generator_wrapper(op.sample_input_generator), + torch_reference=getattr(torch.nn.functional, op.name), + ) + _inplace_opinfos.append(inplace_opinfo) + # in-place ops whose name ends with `_` + if op.op in _functional_to_inplace: + inplace_op = _functional_to_inplace[op.op] + inplace_opinfo = OpInfo( + inplace_op, + sample_input_generator=( + op.sample_input_generator if op.name != "masked_fill" else inplace_masked_fill_sample_generator + ), + torch_reference=_torchsymbol_to_torch[inplace_op], + ) + _inplace_opinfos.append(inplace_opinfo) + + +@dataclass(frozen=True) +class InplaceOpWrapper: + torch_func: Callable + is_polygamma: bool + jitted: bool + + def __call__(self, *args, **kwargs): + # polygamma expects an int as its first argument and a tensor as its second but + # torch.Tensor.polygamma_ wants opposite; tensor first, int second. 
+ # ref: + # - https://pytorch.org/docs/stable/special.html#torch.special.polygamma + # - https://pytorch.org/docs/stable/generated/torch.Tensor.polygamma_.html + args = list(args) + idx = int(self.is_polygamma and self.jitted) + t = args[idx] + 1.0 + args[idx] = t + + self.torch_func(*args, **kwargs) + return t + + +@ops(_inplace_opinfos, supported_dtypes=(dtypes.float32,)) +def test_functionalization(op: OpInfo, device: str, dtype: dtypes.dtype, executor, _): + import thunder + + is_polygamma = op.name == "polygamma_" + inplace_op = InplaceOpWrapper(op.torch_reference, is_polygamma, False) + jitted_inplace_op = thunder.jit( + InplaceOpWrapper(op.torch_reference, is_polygamma, True), + executors=executor.executors_list(), + ) + sample: SampleInput + for idx, sample in enumerate(op.sample_inputs(device, dtype)): + if idx > 0: + break + + args = list(sample.args) + if is_polygamma: + tmp = args[0] + args[0] = args[1] + args[1] = tmp + expected = inplace_op(*args, **sample.kwargs) + actual = jitted_inplace_op(*sample.args, **sample.kwargs) + torch.testing.assert_close(actual, expected, equal_nan=True) + + # make sure `prims.copy_` does not exist in the trace thanks to functionalization + fw_extrace = thunder.last_traces(jitted_inplace_op)[-1] + assert not list( + filter( + lambda bsym: bsym.sym.id == PrimIDs.COPY_, + fw_extrace.bound_symbols, + ) + ) diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py index c1275c08e5..50466e4afd 100644 --- a/thunder/torch/__init__.py +++ b/thunder/torch/__init__.py @@ -1,3 +1,4 @@ +from __future__ import annotations import itertools import math import operator @@ -40,6 +41,7 @@ from thunder.core.prims import get_grad, put_grad from thunder.core.baseutils import run_once + __all__ = [ "is_available", ] @@ -72,6 +74,9 @@ # torch operation definitions # +# in-place sym -> out-of-place (= functional) sym with index of `inplace` argument +_inplace_to_out_of_place: dict[Callable, tuple[Callable, int]] = {} + # A wrapper that executes the operations within the torch language context # NOTE because this module defines the torch language context, a reference to itself @@ -148,6 +153,9 @@ def __call__(self, fn: Callable) -> Symbol: for torchfn in self.torchfns: _torch_to_thunder_function_map[torchfn] = sym + if self.tags and prims.OpTags.IN_PLACE in self.tags: + _inplace_to_out_of_place[sym] = globals()[name[:-1]], -1 + return sym @@ -1241,96 +1249,191 @@ def abs(a: NumberLike | TensorLike, /) -> Number | TensorLike: return clang.abs(a) +@torchsymbol(torch.Tensor.abs_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) +def abs_(a: NumberLike | TensorLike, /) -> Number | TensorLike: + return prims.copy_(abs(a), a) + + @torchsymbol(torch.acos, is_method=True) def acos(a: NumberLike | TensorLike, /) -> Number | TensorLike: return clang.acos(a) +@torchsymbol(torch.Tensor.acos_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) +def acos_(a: TensorLike, /) -> TensorLike: + return prims.copy_(acos(a), a) + + @torchsymbol(torch.acosh, is_method=True) def acosh(a): return clang.acosh(a) +@torchsymbol(torch.Tensor.acosh_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) +def acosh_(a): + return prims.copy_(acosh(a), a) + + @torchsymbol(torch.asin, is_method=True) def asin(a): return clang.asin(a) +@torchsymbol(torch.Tensor.asin_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) +def asin_(a): + return prims.copy_(asin(a), a) + + @torchsymbol(torch.asinh, is_method=True) def asinh(a): return clang.asinh(a) +@torchsymbol(torch.Tensor.asinh_, is_method=True, 
tags=(prims.OpTags.IN_PLACE,)) +def asinh_(a): + return prims.copy_(asinh(a), a) + + @torchsymbol(torch.atan, is_method=True) def atan(a): return clang.atan(a) +@torchsymbol(torch.Tensor.atan_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) +def atan_(a): + return prims.copy_(atan(a), a) + + @torchsymbol(torch.atanh, is_method=True) def atanh(a): return clang.atanh(a) +@torchsymbol(torch.Tensor.atanh_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) +def atanh_(a): + return prims.copy_(atanh(a), a) + + @torchsymbol(torch.bitwise_not, is_method=True) def bitwise_not(a): return clang.bitwise_not(a) +@torchsymbol(torch.Tensor.bitwise_not_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) +def bitwise_not_(a): + return prims.copy_(bitwise_not(a), a) + + @torchsymbol(torch.ceil, is_method=True) def ceil(a): return clang.ceil(a) +@torchsymbol(torch.Tensor.ceil_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) +def ceil_(a): + return prims.copy_(ceil(a), a) + + @torchsymbol(torch.cos, is_method=True) def cos(a): return clang.cos(a) +@torchsymbol(torch.Tensor.cos_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) +def cos_(a): + return prims.copy_(cos(a), a) + + @torchsymbol(torch.cosh, is_method=True) def cosh(a): return clang.cosh(a) +@torchsymbol(torch.Tensor.cosh_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) +def cosh_(a): + return prims.copy_(cosh(a), a) + + @torchsymbol(torch.digamma, torch.special.digamma, is_method=True) def digamma(a): return clang.digamma(a) +@torchsymbol(torch.Tensor.digamma_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) +def digamma_(a): + return prims.copy_(digamma(a), a) + + @torchsymbol(torch.erf, is_method=True) def erf(a): return clang.erf(a) +@torchsymbol(torch.Tensor.erf_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) +def erf_(a): + return prims.copy_(erf(a), a) + + @torchsymbol(torch.erfc, is_method=True) def erfc(a): return clang.erfc(a) +@torchsymbol(torch.Tensor.erfc_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) +def erfc_(a): + return prims.copy_(erfc(a), a) + + @torchsymbol(torch.erfinv, is_method=True) def erfinv(a): return clang.erfinv(a) +@torchsymbol(torch.Tensor.erfinv_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) +def erfinv_(a): + return prims.copy_(erfinv(a), a) + + @torchsymbol(torch.exp, is_method=True) def exp(a): return clang.exp(a) +@torchsymbol(torch.Tensor.exp_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) +def exp_(a): + return prims.copy_(exp(a), a) + + @torchsymbol(torch.exp2, is_method=True) def exp2(a): return clang.exp2(a) +@torchsymbol(torch.Tensor.exp2_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) +def exp2_(a): + return prims.copy_(exp2(a), a) + + @torchsymbol(torch.expm1, is_method=True) def expm1(a): return clang.expm1(a) +@torchsymbol(torch.Tensor.expm1_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) +def expm1_(a): + return prims.copy_(expm1(a), a) + + @torchsymbol(torch.floor, is_method=True) def floor(a): return clang.floor(a) +@torchsymbol(torch.Tensor.floor_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) +def floor_(a): + return prims.copy_(floor(a), a) + + @torchsymbol(torch.isfinite, is_method=True) def isfinite(a): return clang.isfinite(a) @@ -1341,26 +1444,51 @@ def lgamma(a): return clang.lgamma(a) +@torchsymbol(torch.Tensor.lgamma_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) +def lgamma_(a): + return prims.copy_(lgamma(a), a) + + @torchsymbol(torch.log, is_method=True) def log(a): return clang.log(a) +@torchsymbol(torch.Tensor.log_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) +def log_(a): + return 
prims.copy_(log(a), a) + + @torchsymbol(torch.log10, is_method=True) def log10(a): return clang.log10(a) +@torchsymbol(torch.Tensor.log10_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) +def log10_(a): + return prims.copy_(log10(a), a) + + @torchsymbol(torch.log1p, is_method=True) def log1p(a): return clang.log1p(a) +@torchsymbol(torch.Tensor.log1p_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) +def log1p_(a): + return prims.copy_(log1p(a), a) + + @torchsymbol(torch.log2, is_method=True) def log2(a): return clang.log2(a) +@torchsymbol(torch.Tensor.log2_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) +def log2_(a): + return prims.copy_(log2(a), a) + + # TODO Move to special # @torchsymbol(torch.ndtri, is_method=True) # def ndtri(a): @@ -1372,21 +1500,41 @@ def neg(a): return clang.neg(a) +@torchsymbol(torch.Tensor.neg_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) +def neg_(a): + return prims.copy_(neg(a), a) + + @torchsymbol(torch.reciprocal, is_method=True) def reciprocal(a): return clang.reciprocal(a) +@torchsymbol(torch.Tensor.reciprocal_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) +def reciprocal_(a): + return prims.copy_(reciprocal(a), a) + + @torchsymbol(torch.round, is_method=True) def round(a): return clang.round(a) +@torchsymbol(torch.Tensor.round_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) +def round_(a): + return prims.copy_(round(a), a) + + @torchsymbol(torch.rsqrt, is_method=True) def rsqrt(a): return clang.rsqrt(a) +@torchsymbol(torch.Tensor.rsqrt_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) +def rsqrt_(a): + return prims.copy_(rsqrt(a), a) + + # TODO Complain about complex numbers like PyTorch does? # TODO Add sgn @torchsymbol(torch.sign, is_method=True) @@ -1394,6 +1542,11 @@ def sign(a): return clang.sign(a) +@torchsymbol(torch.Tensor.sign_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) +def sign_(a): + return prims.copy_(sign(a), a) + + @torchsymbol(torch.signbit, is_method=True) def signbit(a): return clang.signbit(a) @@ -1404,31 +1557,61 @@ def sin(a): return clang.sin(a) +@torchsymbol(torch.Tensor.sin_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) +def sin_(a): + return prims.copy_(sin(a), a) + + @torchsymbol(torch.sinh, is_method=True) def sinh(a): return clang.sinh(a) +@torchsymbol(torch.Tensor.sinh_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) +def sinh_(a): + return prims.copy_(sinh(a), a) + + @torchsymbol(torch.sqrt, is_method=True) def sqrt(a): return clang.sqrt(a) +@torchsymbol(torch.Tensor.sqrt_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) +def sqrt_(a): + return prims.copy_(sqrt(a), a) + + @torchsymbol(torch.tan, is_method=True) def tan(a): return clang.tan(a) +@torchsymbol(torch.Tensor.tan_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) +def tan_(a): + return prims.copy_(tan(a), a) + + @torchsymbol(torch.tanh, is_method=True) def tanh(a): return clang.tanh(a) +@torchsymbol(torch.Tensor.tanh_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) +def tanh_(a): + return prims.copy_(tanh(a), a) + + @torchsymbol(torch.trunc, is_method=True) def trunc(a): return clang.trunc(a) +@torchsymbol(torch.Tensor.trunc_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) +def trunc_(a): + return prims.copy_(trunc(a), a) + + @torchsymbol(torch.real, is_method=False) def real(a): return clang.real(a) @@ -1457,48 +1640,82 @@ def gelu(a: TensorProxy, /, *, approximate: str = "none") -> TensorLike: # TODO Should this use clamp? -- Would that propagate NaNs properly? 
@torchsymbol(torch.relu, torch.nn.functional.relu, id="torch.relu", is_method=True) def relu(a: TensorLike, /, inplace: bool = False) -> TensorLike: - utils.check(not inplace, lambda: f"relu only supports inplace=False", exception_type=NotImplementedError) - return where(a > 0, a, 0) + out = where(a > 0, a, 0) + if inplace: + return prims.copy_(out, a) + return out + + +_inplace_to_out_of_place[relu] = relu, 1 + + +@torchsymbol(torch.relu_, torch.nn.functional.relu_, id="torch.relu_", is_method=True) +def relu_( + a: TensorLike, + /, +) -> TensorLike: + return prims.copy_(relu(a, False), a) + + +# The default value of `inplace` is False, so no need to tweak args/kwargs +_inplace_to_out_of_place[relu_] = relu, -1 # id=torch.relu because we ignore inplace argument in torch.nn.functional.relu @torchsymbol(torch.nn.functional.relu6, id="torch.relu6", is_method=False) def relu6(a: TensorProxy, /, inplace: bool = False) -> TensorLike: - utils.check(not inplace, lambda: f"relu6 only supports inplace=False", exception_type=NotImplementedError) - return clamp(a, 0, 6) + out = clamp(a, 0, 6) + if inplace: + return prims.copy_(out, a) + return out + + +_inplace_to_out_of_place[relu6] = relu6, 1 @torchsymbol(torch.nn.functional.hardswish, id="torch.hardswish", is_method=False) def hardswish(a: TensorProxy, /, inplace: bool = False) -> TensorLike: - utils.check(not inplace, lambda: f"hardswish only supports inplace=False", exception_type=NotImplementedError) utils.check( dtypes.is_float_dtype(a.dtype), lambda: f"hardswish only supports floating point dtypes, got {a.dtype}", exception_type=ValueError, ) - return a * relu6(a + 3) / 6 + out = a * relu6(a + 3) / 6 + if inplace: + return prims.copy_(out, a) + return out + + +_inplace_to_out_of_place[hardswish] = hardswish, 1 # id=torch.selu because we ignore inplace argument in torch.nn.functional.selu @torchsymbol(torch.selu, torch.nn.functional.selu, id="torch.selu", is_method=False) def selu(a: TensorProxy, /, inplace: bool = False) -> TensorLike: - utils.check(not inplace, lambda: f"selu only supports inplace=False", exception_type=NotImplementedError) - alpha = 1.6732632423543772848170429916717 scale = 1.0507009873554804934193349852946 rhs = alpha * expm1(a) - return scale * where(a > 0, a, rhs) + out = scale * where(a > 0, a, rhs) + if inplace: + return prims.copy_(out, a) + return out + + +_inplace_to_out_of_place[selu] = selu, 1 @torchsymbol(torch.nn.functional.silu) def silu(a: TensorLike, /, inplace: bool = False) -> TensorLike: - utils.check( - not inplace, lambda: "Thunder only supports silu with inplace=False", exception_type=NotImplementedError - ) - return clang.silu(a) + out = clang.silu(a) + if inplace: + return prims.copy_(out, a) + return out + + +_inplace_to_out_of_place[silu] = silu, 1 # @@ -1516,31 +1733,67 @@ def add( return clang.add(a, b) +@torchsymbol(torch.Tensor.add_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) +def add_( + a: TensorLike, + b: NumberLike | TensorLike, + /, + *, + alpha: None | Number | TensorLike = None, +) -> TensorLike: + return prims.copy_(add(a, b, alpha=alpha), a) + + @torchsymbol(torch.atan2, is_method=True) def atan2(a, b, /): return clang.atan2(a, b) +@torchsymbol(torch.Tensor.atan2_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) +def atan2_(a, b, /): + return prims.copy_(atan2(a, b), a) + + @torchsymbol(torch.bitwise_and, is_method=True) def bitwise_and(a, b, /): return clang.bitwise_and(a, b) +@torchsymbol(torch.Tensor.bitwise_and_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) +def bitwise_and_(a, 
b, /): + return prims.copy_(bitwise_and(a, b), a) + + @torchsymbol(torch.bitwise_or, is_method=True) def bitwise_or(a, b, /): return clang.bitwise_or(a, b) +@torchsymbol(torch.Tensor.bitwise_or_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) +def bitwise_or_(a, b, /): + return prims.copy_(bitwise_or(a, b), a) + + @torchsymbol(torch.bitwise_xor, is_method=True) def bitwise_xor(a, b, /): return clang.bitwise_xor(a, b) +@torchsymbol(torch.Tensor.bitwise_xor_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) +def bitwise_xor_(a, b, /): + return prims.copy_(bitwise_xor(a, b), a) + + @torchsymbol(torch.copysign, is_method=True) def copysign(a, b, /): return clang.copysign(a, b) +@torchsymbol(torch.Tensor.copysign_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) +def copysign_(a, b, /): + return prims.copy_(copysign(a, b), a) + + # TODO Implement div @torchsymbol(torch.div, is_method=True) def div( @@ -1563,51 +1816,107 @@ def div( raise ValueError(f"div does not support the rounding_mode={rounding_mode} argument") +@torchsymbol(torch.Tensor.div_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) +def div_( + a: TensorLike, + b: Number | TensorLike, + /, + *, + rounding_mode: None | str = None, +) -> TensorLike: + return prims.copy_(div(a, b), a) + + @torchsymbol(torch.eq, is_method=True) def eq(a, b, /): return clang.eq(a, b) +@torchsymbol(torch.Tensor.eq_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) +def eq_(a, b, /): + return prims.copy_(eq(a, b), a) + + @torchsymbol(torch.floor_divide, is_method=True) def floor_divide(a, b, /): return clang.floor_divide(a, b) +@torchsymbol(torch.Tensor.floor_divide_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) +def floor_divide_(a, b, /): + return prims.copy_(floor_divide(a, b), a) + + @torchsymbol(torch.fmod, is_method=True) def fmod(a, b, /): return clang.fmod(a, b) +@torchsymbol(torch.Tensor.fmod_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) +def fmod_(a, b, /): + return prims.copy_(fmod(a, b), a) + + @torchsymbol(torch.ge, is_method=True) def ge(a, b, /): return clang.ge(a, b) +@torchsymbol(torch.Tensor.ge_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) +def ge_(a, b, /): + return prims.copy_(ge(a, b), a) + + @torchsymbol(torch.gt, is_method=True) def gt(a, b, /): return clang.gt(a, b) +@torchsymbol(torch.Tensor.gt_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) +def gt_(a, b, /): + return prims.copy_(gt(a, b), a) + + @torchsymbol(torch.logical_and, is_method=True) def logical_and(a, b, /): return clang.logical_and(a, b) +@torchsymbol(torch.Tensor.logical_and_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) +def logical_and_(a, b, /): + return prims.copy_(logical_and(a, b), a) + + @torchsymbol(torch.logical_not, is_method=True) def logical_not(a: TensorLike, /) -> TensorLike: return clang.logical_not(a) +@torchsymbol(torch.Tensor.logical_not_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) +def logical_not_(a: TensorLike, /) -> TensorLike: + return prims.copy_(logical_not(a), a) + + @torchsymbol(torch.le, is_method=True) def le(a, b, /): return clang.le(a, b) +@torchsymbol(torch.Tensor.le_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) +def le_(a, b, /): + return prims.copy_(le(a, b), a) + + @torchsymbol(torch.lt, is_method=True) def lt(a, b, /): return clang.lt(a, b) +@torchsymbol(torch.Tensor.lt_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) +def lt_(a, b, /): + return prims.copy_(lt(a, b), a) + + @torchsymbol(torch.maximum, is_method=True) def maximum(a: TensorProxy, b: TensorProxy) -> TensorProxy: return clang.maximum(a, b) @@ -1625,21 
+1934,40 @@ def mod(a, b): return clang.mod(a, b) +def mod_(a, b): + return prims.copy_(mod(a, b), a) + + @torchsymbol(torch.mul, is_method=True) def mul(a, b, /): return clang.mul(a, b) +@torchsymbol(torch.Tensor.mul_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) +def mul_(a, b, /): + return prims.copy_(mul(a, b), a) + + @torchsymbol(torch.ne, is_method=True) def ne(a, b, /): return clang.ne(a, b) +@torchsymbol(torch.Tensor.ne_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) +def ne_(a, b, /): + return prims.copy_(ne(a, b), a) + + @torchsymbol(torch.nextafter, is_method=True) def nextafter(a, b, /): return clang.nextafter(a, b) +@torchsymbol(torch.Tensor.nextafter_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) +def nextafter_(a, b, /): + return prims.copy_(nextafter(a, b), a) + + # TODO Extend to tensor x tensor @torchsymbol(torch.polygamma, torch.special.polygamma, is_method=True) def polygamma(n: int, a: TensorLike, /) -> TensorLike: @@ -1658,16 +1986,31 @@ def polygamma(n: int, a: TensorLike, /) -> TensorLike: return sign * factorial_mul_zeta +@torchsymbol(torch.Tensor.polygamma_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) +def polygamma_(n: int, a: TensorLike, /) -> TensorLike: + return prims.copy_(polygamma(n, a), a) + + @torchsymbol(torch.pow, is_method=True) def pow(a, b, /): return clang.pow(a, b) +@torchsymbol(torch.Tensor.pow_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) +def pow_(a, b, /): + return prims.copy_(pow(a, b), a) + + @torchsymbol(torch.remainder, is_method=True) def remainder(a, b, /): return clang.remainder(a, b) +@torchsymbol(torch.Tensor.remainder_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) +def remainder_(a, b, /): + return prims.copy_(remainder(a, b), a) + + @torchsymbol(torch.sub, is_method=True) def sub(a, b, /, *, alpha=None): if alpha is not None: @@ -1676,11 +2019,21 @@ def sub(a, b, /, *, alpha=None): return clang.sub(a, b) +@torchsymbol(torch.Tensor.sub_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) +def sub_(a, b, /, *, alpha=None): + return prims.copy_(sub(a, b, alpha=alpha), a) + + @torchsymbol(torch.true_divide, is_method=True) def true_divide(a: NumberLike | TensorLike, b: NumberLike | TensorLike, /) -> Number | TensorLike: return clang.true_divide(a, b) +@torchsymbol(torch.Tensor.true_divide_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) +def true_divide_(a: TensorLike, b: NumberLike | TensorLike, /) -> TensorLike: + return prims.copy_(true_divide(a, b)) + + @torchsymbol(torch.special.zeta) def zeta(a, b, /): return clang.zeta(a, b) @@ -1715,11 +2068,21 @@ def addcmul(a: TensorLike, b: TensorLike, c: TensorLike, /, *, value: None | Num return addcmul_addcdiv_helper(a, b, c, add, mul, value=value) +@torchsymbol(torch.Tensor.addcmul_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) +def addcmul_(a: TensorLike, b: TensorLike, c: TensorLike, /, *, value: None | Number = None) -> TensorLike: + return prims.copy_(addcmul(a, b, c, value=value), a) + + @torchsymbol(torch.addcdiv, is_method=True) def addcdiv(a: TensorLike, b: TensorLike, c: TensorLike, /, *, value: None | Number = None) -> TensorLike: return addcmul_addcdiv_helper(a, b, c, add, true_divide, value=value) +@torchsymbol(torch.Tensor.addcdiv_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) +def addcdiv_(a: TensorLike, b: TensorLike, c: TensorLike, /, *, value: None | Number = None) -> TensorLike: + return prims.copy_(addcdiv(a, b, c, value=value), a) + + # # Conditional operations and masking operations # @@ -1760,6 +2123,13 @@ def clamp( return a 
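# Editor's note (illustration only, not part of the patch): every in-place symbol added in this
# hunk follows the same convention -- compute the out-of-place result, then `prims.copy_` it back
# into the input -- and `_inplace_to_out_of_place` records the out-of-place counterpart together
# with the positional index of the `inplace` flag (1 for relu/relu6, 3 for dropout, -1 when the
# signature has no such flag). A minimal sketch of that convention, using the hypothetical names
# `my_op`/`my_op_` (neither exists in torch or thunder):
#
#     @torchsymbol(torch.my_op, is_method=True)                      # hypothetical entry point
#     def my_op(a: TensorLike, /, inplace: bool = False) -> TensorLike:
#         out = clang.my_op(a)                                       # hypothetical out-of-place op
#         return prims.copy_(out, a) if inplace else out             # write back into `a` when inplace=True
#
#     @torchsymbol(torch.Tensor.my_op_, is_method=True, tags=(prims.OpTags.IN_PLACE,))
#     def my_op_(a: TensorLike, /) -> TensorLike:
#         return prims.copy_(my_op(a, False), a)                     # reuse the functional op, then copy_
#
#     _inplace_to_out_of_place[my_op] = my_op, 1                     # `inplace` sits at argument index 1
#     _inplace_to_out_of_place[my_op_] = my_op, -1                   # no `inplace` argument to rewrite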
+@torchsymbol(torch.Tensor.clamp_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) +def clamp_( + a: TensorLike, /, min: None | Number | TensorLike = None, max: None | Number | TensorLike = None +) -> TensorLike: + return prims.copy_(clamp(a, min, max), a) + + def _mask_tensor(a, mask, fill_value): utils.check( dtypes.is_boolean_dtype(mask.dtype), lambda: f"_mask_tensor: mask ({mask.dtype=}) must have a boolean dtype" @@ -1783,6 +2153,11 @@ def masked_fill(a: TensorLike, /, mask: TensorLike, value: NumberLike | TensorLi return result +@torchsymbol(torch.Tensor.masked_fill_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) +def masked_fill_(a: TensorLike, /, mask: TensorLike, value: NumberLike | TensorLike) -> TensorLike: + return prims.copy_(masked_fill(a, mask, value), a) + + # NOTE The key to understanding tril is that it generates a mask # which (by default) masks elements of a matrix (or batch of matrices) # s.t. elements whose row number is greater than or equal to its column number @@ -1805,6 +2180,11 @@ def tril(a: TensorLike, /, diagonal: int = 0, *, fill_value: None | Number = Non return _mask_tensor(a, mask, fill_value) +@torchsymbol(torch.Tensor.tril_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) +def tril_(a: TensorLike, /, diagonal: int = 0, *, fill_value: None | Number = None) -> TensorLike: + return prims.copy_(tril(a, diagonal, fill_value=fill_value), a) + + @torchsymbol(torch.where, is_method=True) def where( pred: TensorLike, a: None | Number | TensorLike = None, b: None | Number | TensorLike = None, / @@ -2225,6 +2605,11 @@ def cumsum(a: TensorLike, dim: int, *, dtype: None | dtypeLike = None) -> Tensor return TensorProxy(like=a, dtype=to_dtype(dtype)) +@torchsymbol(torch.Tensor.cumsum_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) +def cumsum_(a: TensorLike, dim: int, *, dtype: None | dtypeLike = None) -> TensorLike: + return prims.copy_(cumsum(a, dim, dtype=dtype), a) + + @torchsymbol(torch.var, is_method=True) def var( a: TensorProxy, @@ -2295,6 +2680,11 @@ def index_add(a: TensorLike, /, dim: int, index: TensorLike, source: TensorLike) return clang.index_add(a, index, source, dim) +@torchsymbol(torch.Tensor.index_add_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) +def index_add_(a: TensorLike, /, dim: int, index: TensorLike, source: TensorLike) -> TensorLike: + return prims.copy_(index_add(a, dim, index, source), a) + + @torchsymbol(torch.index_select, is_method=True) def index_select(a: TensorLike, /, dim: int, index: TensorLike) -> TensorLike: return clang.take(a, index, dim) @@ -2311,6 +2701,11 @@ def scatter_add(a: TensorLike, /, dim: int, index: TensorLike, src: TensorLike) return clang.scatter_add(a, indices=index, value=src, dim=dim) +@torchsymbol(torch.Tensor.scatter_add_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) +def scatter_add_(a: TensorLike, /, dim: int, index: TensorLike, src: TensorLike) -> TensorLike: + return prims.copy_(scatter_add(a, dim, index, src), a) + + @torchsymbol(torch.take_along_dim) def take_along_dim(a: TensorLike, /, indices: TensorLike, dim: int) -> TensorLike: return clang.take_along_axis(a, indices, dim) @@ -2323,6 +2718,17 @@ def index_put( return clang.index_put(a, indices, values, accumulate) +@torchsymbol(torch.Tensor.index_put_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) +def index_put_( + a: TensorLike, + /, + indices: Sequence[TensorLike], + values: TensorLike, + accumulate: bool = False, +) -> TensorLike: + return prims.copy_(index_put(a, indices, values, accumulate), a) + + # # Linear Algebra operations # @@ 
-3897,7 +4303,13 @@ def dropout(a: TensorProxy, /, p: NumberLike = 0.5, training: bool = True, inpla scale = 1 / (1 - p) dropout_mask = _dropout_helper(a, 1 - p) - return a * dropout_mask * scale + out = a * dropout_mask * scale + if inplace: + return prims.copy_(out, a) + return out + + +_inplace_to_out_of_place[dropout] = dropout, 3 @torchsymbol(torch.nn.functional.embedding, id="torch.nn.functional.embedding") From a78f3c9c769bf355e0ec9611eba1ffd364aee6ab Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Thu, 20 Jun 2024 16:14:40 +0200 Subject: [PATCH 05/14] add quantization transform (#561) --- requirements/test.txt | 1 + thunder/__init__.py | 2 + thunder/core/module.py | 25 ++++ thunder/core/transform_common.py | 20 ++- thunder/tests/test_networks.py | 59 +++++++++ thunder/transforms/quantization.py | 202 +++++++++++++++++++++++++++++ thunder/transforms/utils.py | 49 +++++++ 7 files changed, 357 insertions(+), 1 deletion(-) create mode 100644 thunder/transforms/quantization.py create mode 100644 thunder/transforms/utils.py diff --git a/requirements/test.txt b/requirements/test.txt index 85a3c79ab4..5d920dcd59 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -17,6 +17,7 @@ absl-py # thunder/benchmarks/test_benchmark_litgpt.py pandas # thunder/benchmarks/test_benchmark_litgpt.py xlsxwriter # thunder/benchmarks/test_benchmark_litgpt.py jsonargparse # thunder/benchmarks/benchmark_litgpt.py +bitsandbytes==0.42.0 # fixed version! # Installs JAX on Linux and MacOS jaxlib; sys_platform == 'linux' or sys_platform == 'darwin' # required for jax, see https://github.com/google/jax#installation diff --git a/thunder/__init__.py b/thunder/__init__.py index e647a42aa0..04274f689e 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -702,6 +702,8 @@ def fn_(*args, **kwargs) -> Any: if isinstance(fn, pytorch.nn.Module): fn_ = ThunderModule(fn, fn_) cd._thunder_module_map[id(fn)] = fn_ + for transform in early_transforms: + transform.transform_module(fn_) # Sets compile options and statistics attributes cd._get_computation_and_inputs = get_computation_and_inputs diff --git a/thunder/core/module.py b/thunder/core/module.py index b2f804db5e..8483841f2a 100644 --- a/thunder/core/module.py +++ b/thunder/core/module.py @@ -1,6 +1,7 @@ from contextlib import contextmanager import itertools from typing import Any +import collections import torch as pytorch @@ -97,6 +98,30 @@ def named_buffers(self, prefix="", recurse=True, remove_duplicate=True): remove_duplicate=remove_duplicate, ) + def load_original_state_dict(self, state_dict): + # this loads the state dict incrementally to not exhaust memory + module_names = {n for n, _ in self.named_modules()} + sd_per_module = collections.defaultdict(dict) + for k, v in state_dict.items(): + prefix, sep, _ = k.rpartition(".") + # not great but should not happen too often / deep + while prefix not in module_names: + prefix, sep, _ = prefix.rpartition(".") + sd_per_module[prefix][k[len(prefix) + len(sep) :]] = v + + for submodule_name, sd_part in sd_per_module.items(): + prefix = submodule_name + ("." 
if submodule_name else "") + for transform in self._lc_early_transforms: + sd_part = transform.transform_state_dict_for_submodule(self, submodule_name, sd_part) + for k, v in sd_part.items(): + full_k = prefix + k + if k in self._overrides_parameters: + self._overrides_parameters[full_k] = v + elif k in model._overrides_buffers: + self._overrides_buffers[full_k] = v + else: + raise NotImplementedError(f"don't know how to handle {full_k}") + @contextmanager def no_sync(self): r"""Context manager to disable gradient synchronization in data parallel mode. diff --git a/thunder/core/transform_common.py b/thunder/core/transform_common.py index 4b71d14b45..158a9e32f6 100644 --- a/thunder/core/transform_common.py +++ b/thunder/core/transform_common.py @@ -6,6 +6,7 @@ from itertools import filterfalse from functools import partial +import thunder import thunder.core.prims as prims from thunder.core.baseutils import BoundSymbolInterface from thunder.core.proxies import Proxy, variableify, Variable, TensorProxy @@ -343,10 +344,27 @@ class EarlyTransform(Transform, ABC): the computation trace will also update backward trace. """ - @abstractmethod def transform_traces(self, prologue_trace: Trace, computation_trace: Trace, epilogue_trace: Trace | None, **kwargs): + # default to noop + return prologue_trace, computation_trace, epilogue_trace + + def transform_module(self, model: thunder.ThunderModule): + """Transforms the ThunderModule. This is executed once on application of the transform""" pass + def transform_state_dict_for_submodule( + self, model: thunder.ThunderModule, submodule_name: str, state_dict: dict + ) -> dict: + """ + Implement this to transform the state dict (mostly parameters and buffers) of a module, e.g. when loading + from a state dict of the original model. + + Expected to return a state dict (for chaining or populating overrides). + + Note that state dict keys do not include the submodule name as prefix. 
+ """ + return state_dict + class AdditionalTransform(Transform, ABC): """ diff --git a/thunder/tests/test_networks.py b/thunder/tests/test_networks.py index cac1666cab..70ea083606 100644 --- a/thunder/tests/test_networks.py +++ b/thunder/tests/test_networks.py @@ -204,3 +204,62 @@ def test_hf_bart_self_attn(): tom = thunder.jit(model) thunder_result = tom(inp, None) assert_close(torch_result, thunder_result) + + +@requiresCUDA +def test_quantization(): + from thunder.tests import litgpt_model + from lightning.fabric.plugins import BitsandbytesPrecision + + config = litgpt_model.Config.from_name("llama2-like") + with torch.device("cuda"): + model_fp_reference = litgpt_model.GPT(config).to(torch.bfloat16) + + import lightning as L + + plugins = BitsandbytesPrecision("nf4", torch.bfloat16) + fabric = L.Fabric(devices=1, precision=None, plugins=plugins) + with fabric.init_module(empty_init=True): + model = litgpt_model.GPT(config) + + with fabric.init_tensor(): + # set the max_seq_length to limit the memory usage to what we need + model.max_seq_length = 20 + # enable the kv cache + model.set_kv_cache(batch_size=1) + model.eval() + model.requires_grad_(False) + model = fabric.setup_module(model) + + model.load_state_dict(model_fp_reference.state_dict()) + + x = torch.randint(1, 255, (1, 10), device="cuda") + input_pos = torch.arange(10, device="cuda") + logits_expected = model(x, input_pos) + + from thunder.transforms.quantization import BitsAndBytesLinearQuant4bit, get_bitsandbytes_executor + + bitsandbytes_executor = get_bitsandbytes_executor() + + model_fp_reference.set_kv_cache(1, device="cuda", dtype=torch.bfloat16) + model_fp_reference.max_seq_length = 20 + model_fp_reference.requires_grad_(False) + model_fp_reference.eval() + + jm = thunder.jit( + model_fp_reference, + executors=(bitsandbytes_executor,), + early_transforms=[BitsAndBytesLinearQuant4bit()], + ) + + logits_thunder = jm(x, input_pos) + # check_dtype=False due to litgpt returning float32 + # (maybe that also is the numerical discrepancy?) 
+ assert_close(logits_thunder, logits_expected, atol=2e-2, rtol=1e-3, check_dtype=False) + + sd = {k: v.clone() for k, v in jm.state_dict().items()} + jm.load_original_state_dict(model_fp_reference.state_dict()) + sd2 = {k: v.clone() for k, v in jm.state_dict().items()} + assert len(sd) == len(sd2) + for k, v in sd.items(): + assert_close(v, sd2[k]) diff --git a/thunder/transforms/quantization.py b/thunder/transforms/quantization.py new file mode 100644 index 0000000000..ab86608d49 --- /dev/null +++ b/thunder/transforms/quantization.py @@ -0,0 +1,202 @@ +from collections.abc import Sequence + +import thunder +from thunder.core.transform_common import EarlyTransform +from thunder.core import utils +from thunder.core import prims +import torch + +from .utils import ( + get_orig_and_thunder_module_proxies_from_prologue, + get_checks, + add_trace_output, +) + + +bitsandbytes_executor = None + + +def get_bitsandbytes_executor(): + global bitsandbytes + global bitsandbytes_executor + global bnb_matmul_nf4 + + if bitsandbytes_executor is None: + import bitsandbytes + + bitsandbytes_executor = thunder.extend.OperatorExecutor("quant_bnb", version=0.1) + + def bnb_matmul_nf4_meta(x, qweight, bias, absmax, quant_map, blocksize, dtype, shape): + assert isinstance(shape, Sequence) and len(shape) == 2 + assert x.shape[-1] == shape[1], f"{x.shape=}, rhs {shape=}" + return thunder.TensorProxy(like=x, shape=(*x.shape[:-1], shape[0])) + + def bnb_matmul_nf4_impl(x, qweight, bias, absmax, quant_map, blocksize, dtype, shape): + qs = bitsandbytes.functional.QuantState( + absmax, shape=shape, blocksize=blocksize, code=quant_map, quant_type="nf4", dtype=dtype + ) + + return bitsandbytes.matmul_4bit(x, qweight.t(), bias=bias, quant_state=qs) + + bnb_matmul_nf4 = bitsandbytes_executor.register_operator( + "bnb_matmul_nf4", meta=bnb_matmul_nf4_meta, fn=bnb_matmul_nf4_impl + ) + return bitsandbytes_executor + + +class BitsAndBytesLinearQuant4bit(EarlyTransform): + def __init__(self): + self.quant_states = {} + self.quantized_submodule_names = set() + get_bitsandbytes_executor() + + def transform_module(self, model: thunder.ThunderModule): + self.thunder_module = model + + def convert_linear_submodule(tm, name): + self.quantized_submodule_names.add(name) + weight_name = f"{name}.weight" + w = tm.get_parameter(weight_name) + # device!, double quant support + qw, qs = bitsandbytes.functional.quantize_4bit(w.to("cuda"), quant_type="nf4") + tm._overrides_parameters[weight_name] = qw + tm._overrides_parameters[f"{weight_name}.absmax"] = qs.absmax + tm._overrides_parameters[f"{weight_name}.code"] = qs.code + self.quant_states[weight_name] = {"dtype": qs.dtype, "shape": qs.shape, "blocksize": qs.blocksize} + + for n, submodule in model._model.named_modules(): + if isinstance(submodule, torch.nn.Linear): + convert_linear_submodule(model, n) + + def transform_state_dict_for_submodule(self, model: thunder.ThunderModule, submodule_name: str, state_dict: dict): + # note that state dict entries do not include the submodule name as prefix + if submodule_name not in self.quantized_submodule_names: + return state_dict + weight_name_full = f"{submodule_name}.weight" + qs_dict = self.quant_states[weight_name_full] + w = state_dict["weight"] + assert w.dtype == qs_dict["dtype"] + assert w.shape == qs_dict["shape"] + + qw, qs = bitsandbytes.functional.quantize_4bit(w.to("cuda"), block_size=qs_dict["blocksize"], quant_type="nf4") + + # double quant support... 
+ state_dict = state_dict.copy() + state_dict[weight_name] = qw + state_dict[f"{weight_name}.absmax"] = qs.absmax + state_dict[f"{weight_name}.code"] = qs.code + + return state_dict + + def transform_traces(self, prologue_trace, computation_trace, epilogue_trace, **kwargs): + tm = self.thunder_module + from thunder.core.trace import tracectx + + checks = get_checks(prologue_trace) + + compute_producers, compute_consumers = utils.producers_and_consumers(computation_trace) + + # This is needed because having epilogues adds an additional return tuple + # TODO: unify after https://github.com/Lightning-AI/lightning-thunder/issues/628 + if epilogue_trace is None: + prologue_to_epilogue_outputs = prologue_trace.output + output_subindex = None + else: + prologue_to_epilogue_outputs = prologue_trace.output[0] + output_subindex = 0 + + output_idxes = {id(o): i for i, o in enumerate(prologue_to_epilogue_outputs)} + + computation_trace.push_scope([]) + quantized_proxies: dict[int, str] = {} # id -> name + + new_bsyms = [] + new_compute_inputs = [] + for n, qs in self.quant_states.items(): + param = tm.get_parameter(n) + n_absmax = f"{n}.absmax" + n_code = f"{n}.code" + param_absmax = tm.get_parameter(n_absmax) + param_code = tm.get_parameter(n_code) + check, get_param = checks[n] + quantized_proxies[id(get_param.output)] = n + # check has args: tensor, shape, device, dtype, requires_grad + proxy, _, _, _, requires_grad = check.args + thunder_device = thunder.devices.to_device(param.device) + thunder_device_str = str(thunder_device) + check.args = (proxy, (*param.shape,), thunder_device_str, param.dtype, False) + + output_idx = output_idxes.get(id(get_param.output)) + if output_idx is not None: + with tracectx(prologue_trace): + # better way + proxy_absmax = thunder.TensorProxy( + name=f"{get_param.output.name}_absmax", + shape=param_absmax.shape, + dtype=thunder.dtypes.to_dtype(param_absmax.dtype), + device=thunder.devices.to_device(param_absmax.device), + requires_grad=False, + ) + proxy_code = thunder.TensorProxy( + name=f"{get_param.output.name}_code", + shape=param_code.shape, + dtype=thunder.dtypes.to_dtype(param_code.dtype), + device=thunder.devices.to_device(param_code.device), + requires_grad=False, + ) + # get_param.sym = unpack_buffer/parameter as needed + new_bsyms.append(get_param.sym.bind(get_param.args[0], n_absmax, output=proxy_absmax)) + new_bsyms.append(get_param.sym.bind(get_param.args[0], n_code, output=proxy_code)) + add_trace_output(prologue_trace, proxy_absmax, subindex=output_subindex) + add_trace_output(prologue_trace, proxy_code, subindex=output_subindex) + new_compute_inputs.append(proxy_absmax) + new_compute_inputs.append(proxy_code) + qs["proxy_absmax"] = proxy_absmax + qs["proxy_code"] = proxy_code + compute_input = computation_trace.args[output_idx] + + prologue_trace.bound_symbols[-1:-1] = new_bsyms + + with tracectx(computation_trace): + new_bindings = [thunder.core.prims.unpack_trivial.bind(i, output=i) for i in new_compute_inputs] + + new_computation_trace = thunder.core.trace.from_trace(computation_trace) + new_computation_trace.args = (*new_computation_trace.args, *new_compute_inputs) + new_computation_trace._siginfo.args = [(a.name, None) for a in new_computation_trace.args] + for idx, bsym in enumerate(computation_trace.bound_symbols): + if bsym.sym != prims.unpack_trivial: + break + new_computation_trace.bound_symbols.append(bsym.from_bsym()) + new_computation_trace.bound_symbols += new_bindings + proxies_to_replace = {} + for bsym in 
computation_trace.bound_symbols[idx:]: + if bsym.sym == thunder.torch.linear and id(bsym.args[1]) in quantized_proxies: + assert len(bsym.args) == 3 # torch.linear(input, weight, bias) + n = quantized_proxies[id(bsym.args[1])] + qs = self.quant_states[n] + # signature of the new symbol: + # bnb_matmul_nf4(x, qweight, bias, absmax, quant_map, blocksize, dtype, shape) + new_args = ( + *bsym.args[:3], + qs["proxy_absmax"], + qs["proxy_code"], + qs["blocksize"], + qs["dtype"], + qs["shape"], + ) + mm_bsym = bsym.from_bsym( + sym=bnb_matmul_nf4, + subsymbols=[], + args=new_args, + ) + + new_computation_trace.bound_symbols.append(mm_bsym) + # we need the postprocess to set the internal state (call_ctx) because we do not bind / execute the new symbol to + # preserve the "meta"-info like source location, header, etc. + # TODO: switch to a better solution when it is there + bnb_matmul_nf4._bind_postprocess(mm_bsym) + else: + new_computation_trace.bound_symbols.append(bsym.from_bsym()) + + new_computation_trace.set_provenance(thunder.core.trace.TraceProvenance("quant pass")) + return prologue_trace, new_computation_trace, epilogue_trace diff --git a/thunder/transforms/utils.py b/thunder/transforms/utils.py new file mode 100644 index 0000000000..5b0df34b73 --- /dev/null +++ b/thunder/transforms/utils.py @@ -0,0 +1,49 @@ +from thunder.core import utils +from thunder.core import prims + + +def get_orig_and_thunder_module_proxies_from_prologue(prologue_trace): + modules_and_thunder_modules = [ + (bsym.args[0], bsym.output) for bsym in prologue_trace.bound_symbols if bsym.sym is prims.unpack_thunder_module + ] + + if len(modules_and_thunder_modules) != 1: + raise NotImplementedError("cannot deal with modules other than the compiled module") + + ((orig_module_proxy, thunder_module_proxy),) = modules_and_thunder_modules + if prologue_producers[orig_module_proxy].sym is not prims.unpack_function_obj: + raise NotImplementedError("original module does not match the compiled module") + + return orig_module_proxy, thunder_module_proxy + + +def get_checks(prologue_trace): + check_dict = {} + prologue_producers, prologue_consumers = utils.producers_and_consumers(prologue_trace) + for bsym in prologue_trace.bound_symbols: + if bsym.sym == prims.unpack_parameter or bsym.sym == prims.unpack_buffer: + param_thunder_module, param_name = bsym.args + checks = [ + bsym2 for bsym2 in prologue_consumers[bsym.output] if bsym2.sym == prims.check_tensor_shape_and_metadata + ] + assert ( + len(checks) == 1 + ), f"expected each parameter and buffer to have exactly one checker, but {bsym.output} has {len(checks)}" + assert isinstance(param_name, str) + check_dict[param_name] = (checks[0], bsym) + return check_dict + + +def add_trace_output(trace, output, subindex: int | None = None): + ret_node = trace.bound_symbols[-1] + assert ret_node.sym == prims.python_return + assert len(ret_node.args) == 1 + (ret_args,) = ret_node.args + + if subindex is None: + ret_args = (*ret_args, output) + else: + assert isinstance(ret_args[subindex], tuple) + ret_args = (*ret_args[:subindex], (*ret_args[subindex], output), *ret_args[subindex + 1 :]) + + ret_node.args = (ret_args,) From a5d5a0c6927d500171017557c10f970dc053e775 Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Fri, 21 Jun 2024 02:38:57 +0900 Subject: [PATCH 06/14] Disallow in-place to view tensors (#630) --- thunder/__init__.py | 2 ++ thunder/core/transform_common.py | 23 ++++++++++++++++ .../tests/test_inplace_functionalization.py | 27 ++++++++++++++++++- 3 files changed, 51 
insertions(+), 1 deletion(-) diff --git a/thunder/__init__.py b/thunder/__init__.py index 04274f689e..5026921140 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -39,6 +39,7 @@ AdditionalTransform, PostOptimizationTransform, functionalize_inplace_ops, + check_inplace_to_views, ) from thunder.common import ( CompileData, @@ -509,6 +510,7 @@ def get_computation_and_inputs(*args, **kwargs): prologue_traces = [prologue_trc] computation_traces = [computation_trc] + check_inplace_to_views(computation_trc) if not compile_options.get("skip_inplace_functionalization", False): computation_traces.extend(functionalize_inplace_ops(computation_trace=computation_trc)) computation_trc = computation_traces[-1] diff --git a/thunder/core/transform_common.py b/thunder/core/transform_common.py index 158a9e32f6..84a35485a2 100644 --- a/thunder/core/transform_common.py +++ b/thunder/core/transform_common.py @@ -388,6 +388,29 @@ def transform_trace(self, computation_trace: Trace, **kwargs): pass +def check_inplace_to_views(computation_trace: Trace) -> None: + """Error out if ``computation_trace`` has any in-place op of `torch.reshape`'s output.""" + from thunder.core import utils + import thunder.torch as ltorch + + producer_bsyms = producers(computation_trace) + + bsym: BoundSymbol + for bsym in filter(lambda b: has_tags(b, {prims.OpTags.IN_PLACE}), computation_trace.bound_symbols): + for in_tensor in filter(lambda p: isinstance(p, TensorProxy), bsym.flat_proxy_args): + prod_bsym: BoundSymbol = producer_bsyms[in_tensor] + utils.check( + not has_tags(prod_bsym, {prims.OpTags.SHAPE_OP}), + lambda: f"in-place op to view tensors is not allowed but `{bsym.sym.id}` takes `{prod_bsym.sym.id}` output `{in_tensor}`", + NotImplementedError, + ) + utils.check( + prod_bsym.sym != ltorch.contiguous, + lambda: f"in-place op to `torch.Tensor.contiguous` output is not allowed but `{bsym.sym.id}` takes `{prod_bsym.sym.id}` output `{in_tensor}`", + NotImplementedError, + ) + + def functionalize_inplace_ops(computation_trace: Trace) -> list[Trace]: """Functionalize in-place ops in ``computation_trace``. 
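Editor's note (illustration, not part of the patch): the check above refuses in-place ops whose input
comes from a shape op (e.g. `torch.reshape`) or from `torch.Tensor.contiguous`, since an in-place write
through a view would also have to show up in the base tensor -- and for reshape/contiguous it is even
input-dependent whether the output aliases the input at all. A small eager-mode sketch of that aliasing
behaviour (plain PyTorch, independent of thunder):

    import torch

    a = torch.arange(6.0).reshape(2, 3)
    v = a.reshape(-1)      # contiguous input: reshape returns a view sharing a's storage
    v.mul_(2)              # ...so this also doubles the values seen through `a`

    b = a.t().reshape(-1)  # non-contiguous input: reshape has to return a copy
    b.mul_(2)              # ...so this leaves `a` untouched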
diff --git a/thunder/tests/test_inplace_functionalization.py b/thunder/tests/test_inplace_functionalization.py index 83f43bc6b4..b13299710f 100644 --- a/thunder/tests/test_inplace_functionalization.py +++ b/thunder/tests/test_inplace_functionalization.py @@ -1,8 +1,9 @@ from __future__ import annotations +from collections.abc import Callable from dataclasses import dataclass from functools import partial -from collections.abc import Callable +import pytest import torch.testing from thunder.core import dtypes @@ -122,3 +123,27 @@ def test_functionalization(op: OpInfo, device: str, dtype: dtypes.dtype, executo fw_extrace.bound_symbols, ) ) + + +def test_invalid_cases(): + import thunder + + a = torch.randn((2, 2)) + + def f_with_reshape(a: torch.Tensor) -> torch.Tensor: + b = torch.reshape(a, (-1,)) + b.exp_() + return b + + jitted = thunder.jit(f_with_reshape) + with pytest.raises(NotImplementedError, match="in-place op to view tensors is not allowed but"): + jitted(a) + + def f_with_contiguous(a: torch.Tensor) -> torch.Tensor: + b = a.contiguous() + b.exp_() + return b + + jitted = thunder.jit(f_with_contiguous) + with pytest.raises(NotImplementedError, match="in-place op to `torch.Tensor.contiguous`"): + jitted(a) From a87d2de4da77fd0430db61499ef554adb234ef74 Mon Sep 17 00:00:00 2001 From: Taylor Robie Date: Thu, 20 Jun 2024 13:25:26 -0700 Subject: [PATCH 07/14] Remove robieta from CODEOWNERS (#631) --- .github/CODEOWNERS | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 11c973b123..b75ae0a931 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -2,7 +2,7 @@ # These owners will be the default owners for everything in the repo. Unless a later match takes precedence, # @global-owner1, @global-owner2, and @global-owner3 will be requested for review when someone opens a pull request. -* @mruberry @lantiga @robieta @t-vi @carmocca +* @mruberry @lantiga @t-vi @carmocca # CI/CD and configs /.azure/ @borda @lantiga @t-vi @carmocca From c01ea88832e20787f872276e1487efaccf44762f Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Fri, 21 Jun 2024 17:46:11 +0900 Subject: [PATCH 08/14] Fix in-place to views condition (#633) --- thunder/core/transform_common.py | 23 ++++++++++++-- .../tests/test_inplace_functionalization.py | 31 ++++++++++++++++++- 2 files changed, 51 insertions(+), 3 deletions(-) diff --git a/thunder/core/transform_common.py b/thunder/core/transform_common.py index 84a35485a2..3416a60d27 100644 --- a/thunder/core/transform_common.py +++ b/thunder/core/transform_common.py @@ -395,12 +395,31 @@ def check_inplace_to_views(computation_trace: Trace) -> None: producer_bsyms = producers(computation_trace) + # note(crcrpar): Why not using :func:`~thunder.core.symbol.has_tags`? + # Because it looks into `.sym.tags` of the input bsym and its subsymbols, + # thus even `ltorch.batch_norm` is regarded as `prims.OpTags.IN_PLACE`. + def has_tag(bsym: BoundSymbol, tag: prims.OpTags) -> bool: + return bsym.sym.tags and tag in bsym.sym.tags + + # note(crcrpar): Following ops would not look like a `prims.OpTags.SHAPE_OP` + # especially with respect to the relationship between the input and the output + # but some of their sub boundsymbols are. Therefore `thunder.core.symbol.gather_tags` gives it to them. 
+ allowed_ltorch_ops = { + ltorch.batch_norm, + ltorch.avg_pool1d, + ltorch.avg_pool2d, + ltorch.avg_pool3d, + ltorch.max_pool1d, + ltorch.max_pool2d, + ltorch.max_pool3d, + } + bsym: BoundSymbol - for bsym in filter(lambda b: has_tags(b, {prims.OpTags.IN_PLACE}), computation_trace.bound_symbols): + for bsym in filter(lambda b: has_tag(b, prims.OpTags.IN_PLACE), computation_trace.bound_symbols): for in_tensor in filter(lambda p: isinstance(p, TensorProxy), bsym.flat_proxy_args): prod_bsym: BoundSymbol = producer_bsyms[in_tensor] utils.check( - not has_tags(prod_bsym, {prims.OpTags.SHAPE_OP}), + not has_tags(prod_bsym, {prims.OpTags.SHAPE_OP}) or prod_bsym.sym in allowed_ltorch_ops, lambda: f"in-place op to view tensors is not allowed but `{bsym.sym.id}` takes `{prod_bsym.sym.id}` output `{in_tensor}`", NotImplementedError, ) diff --git a/thunder/tests/test_inplace_functionalization.py b/thunder/tests/test_inplace_functionalization.py index b13299710f..5d2ab5c731 100644 --- a/thunder/tests/test_inplace_functionalization.py +++ b/thunder/tests/test_inplace_functionalization.py @@ -8,7 +8,7 @@ from thunder.core import dtypes from thunder.core.prims import PrimIDs -from thunder.tests.framework import ops +from thunder.tests.framework import ops, requiresCUDA from thunder.tests.opinfos import opinfos, OpInfo, make_number, SampleInput from thunder.tests.make_tensor import make_tensor from thunder.torch import _torch_to_thunder_function_map, _inplace_to_out_of_place @@ -147,3 +147,32 @@ def f_with_contiguous(a: torch.Tensor) -> torch.Tensor: jitted = thunder.jit(f_with_contiguous) with pytest.raises(NotImplementedError, match="in-place op to `torch.Tensor.contiguous`"): jitted(a) + + +# TODO(crcrpar): Investigate the numerical accuracy when `train=True` and dtype is fp32. +# with RTX6000 Ada and CUDA 12.3, I see somewhat huge error: +# E AssertionError: Tensor-likes are not close! 
+# E +# E Mismatched elements: 913 / 1000 (91.3%) +# E Greatest absolute difference: 0.000273287296295166 at index (0, 50) (up to 1e-05 allowed) +# E Greatest relative difference: 0.4177769422531128 at index (0, 727) (up to 1.3e-06 allowed) +@requiresCUDA +@pytest.mark.parametrize("train", (False, True)) +def test_parse_resnet18(train: bool): + import thunder + + torchvision = pytest.importorskip("torchvision") + + device = torch.device("cuda") + dtype = torch.float64 if train else torch.float32 + with device: + model: nn.Module = torchvision.models.resnet18(weights=None).to(device=device, dtype=dtype) + ref_model: nn.Module = torchvision.models.resnet18(weights=None).to(device=device, dtype=dtype) + if not train: + model = model.eval() + ref_model = ref_model.eval() + ref_model.load_state_dict(model.state_dict()) + + jitted = thunder.jit(model) + x = make_tensor((1, 3, 224, 224), dtype=dtype, device=device) + torch.testing.assert_close(jitted(x), ref_model(x)) From 9f9dcafc9ba5b07652bbab91a602aec3c628c8d1 Mon Sep 17 00:00:00 2001 From: Kaeun Kim <51257208+k223kim@users.noreply.github.com> Date: Fri, 21 Jun 2024 17:51:20 +0900 Subject: [PATCH 09/14] normalize test fix (#629) --- thunder/tests/opinfos.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/thunder/tests/opinfos.py b/thunder/tests/opinfos.py index 66c15e6ce1..bd5e5efeba 100644 --- a/thunder/tests/opinfos.py +++ b/thunder/tests/opinfos.py @@ -5931,11 +5931,14 @@ def normalize_sample_generator(op, device, dtype, requires_grad, **kwargs): (4, 2, 4, 5), ) for case in cases: - yield SampleInput(make(case), eps=1e-8) - yield SampleInput(make(case), p=0, eps=1e-8) - yield SampleInput(make(case), p=1, eps=1e-8) - yield SampleInput(make(case), p=4, eps=1e-8) - yield SampleInput(make(case), p=math.inf, eps=1e-8) + input_tensor = make(case) + # avoid very small norm tensors, which can be unstable to normalize + input_tensor = input_tensor + 0.2 * torch.sign(input_tensor) + yield SampleInput(input_tensor, eps=1e-8) + yield SampleInput(input_tensor, p=0, eps=1e-8) + yield SampleInput(input_tensor, p=1, eps=1e-8) + yield SampleInput(input_tensor, p=4, eps=1e-8) + yield SampleInput(input_tensor, p=math.inf, eps=1e-8) normalize_opinfo = OpInfo( From 2303b30344a8b1af9214010c96d545312a38fe13 Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Sat, 22 Jun 2024 17:09:39 +0200 Subject: [PATCH 10/14] always return epilogue inputs from prologue (#636) --- thunder/__init__.py | 6 +++--- thunder/core/jit_ext.py | 11 +++-------- thunder/distributed/tensor_parallel/common.py | 2 +- thunder/distributed/transforms/fsdp_v2.py | 2 +- thunder/transforms/quantization.py | 15 ++++----------- 5 files changed, 12 insertions(+), 24 deletions(-) diff --git a/thunder/__init__.py b/thunder/__init__.py index 5026921140..23edf37b72 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -425,7 +425,7 @@ def get_computation_and_inputs(*args, **kwargs): ) = cache_entry try: cs.last_prologue_execution_start = time.time_ns() - if epilogue: + if interpretation is INTERPRETATION_OPTIONS.TRANSLATE_PYTHON: inps, pro_to_epi = pro(*args, **kwargs) else: inps = pro(*args, **kwargs) @@ -466,7 +466,7 @@ def get_computation_and_inputs(*args, **kwargs): ) = cache_entry cs.last_prologue_execution_start = time.time_ns() - if epilogue: + if interpretation is INTERPRETATION_OPTIONS.TRANSLATE_PYTHON: inps, pro_to_epi = pro(*args, **kwargs) else: inps = pro(*args, **kwargs) @@ -552,7 +552,7 @@ def get_computation_and_inputs(*args, **kwargs): 
cs.last_prologue_transformation_stop = time.time_ns() cs.last_prologue_execution_start = time.time_ns() - if epilogue: + if interpretation is INTERPRETATION_OPTIONS.TRANSLATE_PYTHON: inps, pro_to_epi = pro(*args, **kwargs) else: inps = pro(*args, **kwargs) diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py index 0ce1d51400..c9e9f1a2d0 100644 --- a/thunder/core/jit_ext.py +++ b/thunder/core/jit_ext.py @@ -1275,7 +1275,7 @@ def get_parameter_or_buffer_or_submodule_name_and_root(provenance): return typ, name, mprovenance -def unpack_inputs(ctx, prologue_trace, pro_to_comp_inps, pro_to_epi_inps, args, kwargs, *, has_epilogue: bool): +def unpack_inputs(ctx, prologue_trace, pro_to_comp_inps, pro_to_epi_inps, args, kwargs): already_unpacked: dict[int, Proxy] = {} orig_modules: dict[int, Proxy] = {} @@ -1523,10 +1523,7 @@ def from_provenance(provenance, *, new_output=False): else: raise NotImplementedError(f"cache info of type {type(v).__name__}") - if has_epilogue: - prims.python_return((pro_to_comp, pro_to_epi)) - else: - prims.python_return(pro_to_comp) + prims.python_return((pro_to_comp, pro_to_epi)) return pro_to_comp, pro_to_epi @@ -1692,9 +1689,7 @@ def thunder_general_jit( else: epilogue_trace = None - pro_to_comp_proxies, pro_to_epi_proxies = unpack_inputs( - ctx, prologue_trace, pro_to_comp, pro_to_epi, args, kwargs, has_epilogue=epilogue_trace is not None - ) + pro_to_comp_proxies, pro_to_epi_proxies = unpack_inputs(ctx, prologue_trace, pro_to_comp, pro_to_epi, args, kwargs) proxy_order = {id(p): i for i, p in enumerate(pro_to_comp_proxies)} pro_to_comp = tuple(sorted(pro_to_comp, key=lambda v: proxy_order[id(v.proxy)])) diff --git a/thunder/distributed/tensor_parallel/common.py b/thunder/distributed/tensor_parallel/common.py index 73072909cb..9b7770f94b 100644 --- a/thunder/distributed/tensor_parallel/common.py +++ b/thunder/distributed/tensor_parallel/common.py @@ -210,7 +210,7 @@ def transform_traces( prologue_producers, prologue_consumers = utils.producers_and_consumers(prologue_trace) pro_out_p: TensorProxy comp_inp_p: TensorProxy - for pro_out_p, comp_inp_p in zip(prologue_trace.output, computation_trace.args): + for pro_out_p, comp_inp_p in zip(prologue_trace.output[0], computation_trace.args): if pro_out_p.name not in self.chunked_param_name_to_layer_type: continue bsym = prologue_producers[pro_out_p] diff --git a/thunder/distributed/transforms/fsdp_v2.py b/thunder/distributed/transforms/fsdp_v2.py index fe96e1d9ea..71c51e7698 100644 --- a/thunder/distributed/transforms/fsdp_v2.py +++ b/thunder/distributed/transforms/fsdp_v2.py @@ -52,7 +52,7 @@ def transform_traces(self, prologue_trace, computation_trace, epilogue_trace, ** synchronized_parameters = [] param_name_to_comp_trc_proxy = {} # Track param_name to it's corresponding proxy in computation_trc. 
# todo: deal with epilogue output - for pro_out_p, comp_inp_p in zip(prologue_trace.output, computation_trace.args): + for pro_out_p, comp_inp_p in zip(prologue_trace.output[0], computation_trace.args): bsym = prologue_producers[pro_out_p] if bsym.sym == prims.unpack_parameter: param_thunder_module, param_name = bsym.args diff --git a/thunder/transforms/quantization.py b/thunder/transforms/quantization.py index ab86608d49..c62f90c16c 100644 --- a/thunder/transforms/quantization.py +++ b/thunder/transforms/quantization.py @@ -96,16 +96,9 @@ def transform_traces(self, prologue_trace, computation_trace, epilogue_trace, ** compute_producers, compute_consumers = utils.producers_and_consumers(computation_trace) - # This is needed because having epilogues adds an additional return tuple - # TODO: unify after https://github.com/Lightning-AI/lightning-thunder/issues/628 - if epilogue_trace is None: - prologue_to_epilogue_outputs = prologue_trace.output - output_subindex = None - else: - prologue_to_epilogue_outputs = prologue_trace.output[0] - output_subindex = 0 + proglogue_to_compute_outputs = prologue_trace.output[0] - output_idxes = {id(o): i for i, o in enumerate(prologue_to_epilogue_outputs)} + output_idxes = {id(o): i for i, o in enumerate(proglogue_to_compute_outputs)} computation_trace.push_scope([]) quantized_proxies: dict[int, str] = {} # id -> name @@ -147,8 +140,8 @@ def transform_traces(self, prologue_trace, computation_trace, epilogue_trace, ** # get_param.sym = unpack_buffer/parameter as needed new_bsyms.append(get_param.sym.bind(get_param.args[0], n_absmax, output=proxy_absmax)) new_bsyms.append(get_param.sym.bind(get_param.args[0], n_code, output=proxy_code)) - add_trace_output(prologue_trace, proxy_absmax, subindex=output_subindex) - add_trace_output(prologue_trace, proxy_code, subindex=output_subindex) + add_trace_output(prologue_trace, proxy_absmax, subindex=0) + add_trace_output(prologue_trace, proxy_code, subindex=0) new_compute_inputs.append(proxy_absmax) new_compute_inputs.append(proxy_code) qs["proxy_absmax"] = proxy_absmax From 4ce822b780afe9c95ab020498a5e35d0389368b4 Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Sun, 23 Jun 2024 22:45:11 +0900 Subject: [PATCH 11/14] Partially support in-place ops and tensor aliases (#597) --- thunder/__init__.py | 8 +- thunder/core/transform_common.py | 114 +++++++++++++----- .../tests/test_inplace_functionalization.py | 110 +++++++++++++---- thunder/torch/__init__.py | 26 ++++ 4 files changed, 199 insertions(+), 59 deletions(-) diff --git a/thunder/__init__.py b/thunder/__init__.py index 23edf37b72..12f82953d3 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -510,9 +510,13 @@ def get_computation_and_inputs(*args, **kwargs): prologue_traces = [prologue_trc] computation_traces = [computation_trc] - check_inplace_to_views(computation_trc) + orig_to_view_swap_map = check_inplace_to_views(computation_trc) if not compile_options.get("skip_inplace_functionalization", False): - computation_traces.extend(functionalize_inplace_ops(computation_trace=computation_trc)) + computation_traces.extend( + functionalize_inplace_ops( + computation_trace=computation_trc, orig_to_view_swap_map=orig_to_view_swap_map + ) + ) computation_trc = computation_traces[-1] if epilogue_trc is not None: diff --git a/thunder/core/transform_common.py b/thunder/core/transform_common.py index 3416a60d27..122dfee68a 100644 --- a/thunder/core/transform_common.py +++ b/thunder/core/transform_common.py @@ -9,10 +9,10 @@ import thunder import thunder.core.prims 
as prims from thunder.core.baseutils import BoundSymbolInterface -from thunder.core.proxies import Proxy, variableify, Variable, TensorProxy +from thunder.core.proxies import Proxy, variableify, Variable, TensorProxy, unvariableify from thunder.core.pytree import tree_flatten, tree_map, tree_unflatten from thunder.core.symbol import BoundSymbol, BoundSymbolRHS, has_tags -from thunder.core.trace import from_trace, TraceProvenance, TraceCtx as Trace +from thunder.core.trace import from_trace, TraceProvenance, TraceCtx as Trace, tracectx from thunder.core.utils import ProxyDict, producers, check if TYPE_CHECKING: @@ -388,12 +388,17 @@ def transform_trace(self, computation_trace: Trace, **kwargs): pass -def check_inplace_to_views(computation_trace: Trace) -> None: - """Error out if ``computation_trace`` has any in-place op of `torch.reshape`'s output.""" +def check_inplace_to_views(computation_trace: Trace) -> dict[VariableInterface, TensorProxy]: + """Error out if in-place op that outputs of different number of elements from the input and the input has other consumers.""" from thunder.core import utils import thunder.torch as ltorch producer_bsyms = producers(computation_trace) + trace_args_set = ProxyDict() + for a in filter( + lambda a: isinstance(a, TensorProxy), tree_flatten((computation_trace.args, computation_trace.kwargs))[0] + ): + trace_args_set[a] = a # note(crcrpar): Why not using :func:`~thunder.core.symbol.has_tags`? # Because it looks into `.sym.tags` of the input bsym and its subsymbols, @@ -401,36 +406,52 @@ def check_inplace_to_views(computation_trace: Trace) -> None: def has_tag(bsym: BoundSymbol, tag: prims.OpTags) -> bool: return bsym.sym.tags and tag in bsym.sym.tags - # note(crcrpar): Following ops would not look like a `prims.OpTags.SHAPE_OP` - # especially with respect to the relationship between the input and the output - # but some of their sub boundsymbols are. Therefore `thunder.core.symbol.gather_tags` gives it to them. - allowed_ltorch_ops = { - ltorch.batch_norm, - ltorch.avg_pool1d, - ltorch.avg_pool2d, - ltorch.avg_pool3d, - ltorch.max_pool1d, - ltorch.max_pool2d, - ltorch.max_pool3d, - } - + swap_map: dict[VariableInterface, TensorProxy] = {} + consumers = utils.consumers(computation_trace) bsym: BoundSymbol for bsym in filter(lambda b: has_tag(b, prims.OpTags.IN_PLACE), computation_trace.bound_symbols): - for in_tensor in filter(lambda p: isinstance(p, TensorProxy), bsym.flat_proxy_args): - prod_bsym: BoundSymbol = producer_bsyms[in_tensor] - utils.check( - not has_tags(prod_bsym, {prims.OpTags.SHAPE_OP}) or prod_bsym.sym in allowed_ltorch_ops, - lambda: f"in-place op to view tensors is not allowed but `{bsym.sym.id}` takes `{prod_bsym.sym.id}` output `{in_tensor}`", - NotImplementedError, - ) - utils.check( - prod_bsym.sym != ltorch.contiguous, - lambda: f"in-place op to `torch.Tensor.contiguous` output is not allowed but `{bsym.sym.id}` takes `{prod_bsym.sym.id}` output `{in_tensor}`", - NotImplementedError, - ) + in_tensor: TensorProxy = list(filter(lambda p: isinstance(p, TensorProxy), bsym.flat_proxy_args))[0] + if in_tensor in trace_args_set: + continue + prod_bsym: BoundSymbol = producer_bsyms[in_tensor] + orig_tensor = prod_bsym.flat_proxy_args[0] + consumer_of_orig_tensor = consumers[orig_tensor] + # When the orig tensor is not used by consumers other than `prod_bsym`, it'd be safe. + # Otherwise, we'd need to replace the use of ``orig_tensor`` with a view, unless the original + # is an arg or a kwarg. 
+ if len(consumer_of_orig_tensor) == 1: + continue -def functionalize_inplace_ops(computation_trace: Trace) -> list[Trace]: + utils.check( + prod_bsym.sym not in ltorch._syms_returning_runtime_dependently_views, + lambda: ( + f"in-place op of `{bsym.sym.id}` to `{prod_bsym.sym.id}` output `{in_tensor}` is not " + f"supported. It's unclear if the output of " + f"{tuple(s.id for s in ltorch._syms_returning_runtime_dependently_views)} is " + f"a copy, a view, or the input itself, as per https://pytorch.org/docs/stable/tensor_view.html" + ), + NotImplementedError, + ) + if prod_bsym.sym not in ltorch._syms_returning_views: + continue + + utils.check( + orig_tensor.numel == in_tensor.numel, + lambda: ( + f"in-place op of `{bsym.sym.id}` to `{in_tensor}`, a view tensor of " + f"`{orig_tensor}` is not supported because {in_tensor.numel} != {orig_tensor.numel}" + ), + NotImplementedError, + ) + + swap_map[variableify(orig_tensor)] = in_tensor + return swap_map + + +def functionalize_inplace_ops( + computation_trace: Trace, orig_to_view_swap_map: dict[VariableInterface, TensorProxy] +) -> list[Trace]: """Functionalize in-place ops in ``computation_trace``. In thunder, an in-place is an out-of-place or functional op followed by :func:`~thunder.core.prims.copy_`. @@ -459,9 +480,28 @@ def is_functionalizable(bsym: BoundSymbol) -> bool: bsym: BoundSymbol swap_map: dict[VariableInterface, ProxyInterface] = {} bsyms: list[BoundSymbol] = [] + tensors_observed: set[VariableInterface] = set() for bsym in computation_trace.bound_symbols: new_bsym = bsym.from_bsym_swap_proxies(swap_map) + cur_orig_to_view_swap_map: dict[VariableInterface, TensorProxy] = {} + for t in filter(lambda p: isinstance(p, TensorProxy), new_bsym.flat_args): + if (var_t := variableify(t)) not in tensors_observed: + tensors_observed.add(var_t) + else: + if var_t in orig_to_view_swap_map: + var_view_t = variableify(orig_to_view_swap_map[var_t]) + check(var_view_t in swap_map, lambda: f"{var_view_t} not in {swap_map}, {orig_to_view_swap_map = }") + cur_orig_to_view_swap_map[var_t] = swap_map[var_view_t] + if cur_orig_to_view_swap_map: + with tracectx(computation_trace): + for var_orig, view in cur_orig_to_view_swap_map.items(): + view_of_orig_shape = prims.reshape.meta(view, unvariableify(var_orig).shape) + reshape_bsym = prims.reshape.bind(view, unvariableify(var_orig).shape, output=view_of_orig_shape) + cur_orig_to_view_swap_map[var_orig] = view_of_orig_shape + bsyms.append(reshape_bsym) + new_bsym = bsym.from_bsym_swap_proxies(cur_orig_to_view_swap_map, skip_output=True) + # in-place functionalizable ops has `prims.copy_` as the last subsymbol. if not is_functionalizable(new_bsym): bsyms.append(new_bsym) @@ -476,12 +516,20 @@ def is_functionalizable(bsym: BoundSymbol) -> bool: bsyms.append(new_bsym) intermediate_trace = from_trace(computation_trace) - intermediate_trace.bound_symbols = bsyms[:] + intermediate_trace.bound_symbols = bsyms intermediate_trace.set_provenance(TraceProvenance("Intermediate trace of `functionalize_inplace_ops`")) - del bsyms + + intermediate_trace.bound_symbols[-1] = intermediate_trace.bound_symbols[-1].from_bsym_swap_proxies(swap_map) + return_bsym = intermediate_trace.bound_symbols[-1] + for t in filter(lambda p: isinstance(p, TensorProxy), return_bsym.flat_args): + check( + (var_t := variableify(t)) not in swap_map, + lambda: f"{return_bsym.flat_args=}. 
`{t}` should have been replaced by `{swap_map[var_t]}`, {new_return_bsym=}", + ) # Step 2: Remove `prims.copy_` if it's the last one of `bsym.subsymbols`, # unless `copy_to` is `computation_trace.args` or `computation_trace.kwargs` + producer_map = producers(intermediate_trace) trace_args_set = ProxyDict() for a in filter( lambda a: isinstance(a, TensorProxy), tree_flatten((computation_trace.args, computation_trace.kwargs))[0] @@ -489,6 +537,7 @@ def is_functionalizable(bsym: BoundSymbol) -> bool: trace_args_set[a] = a bsym_inplace_to_functional = {} swap_map.clear() + new_bsyms: list[BoundSymbol] = [] for bsym in intermediate_trace.bound_symbols: new_bsym = bsym.from_bsym_swap_proxies(swap_map) @@ -496,6 +545,7 @@ def is_functionalizable(bsym: BoundSymbol) -> bool: if not is_functionalizable(new_bsym): new_bsyms.append(new_bsym) continue + copy_bsym = bsym.subsymbols[-1] copy_return = copy_bsym.flat_proxy_outs[0] copy_from = copy_bsym.flat_proxy_args[0] diff --git a/thunder/tests/test_inplace_functionalization.py b/thunder/tests/test_inplace_functionalization.py index 5d2ab5c731..0eb591c2f8 100644 --- a/thunder/tests/test_inplace_functionalization.py +++ b/thunder/tests/test_inplace_functionalization.py @@ -8,7 +8,7 @@ from thunder.core import dtypes from thunder.core.prims import PrimIDs -from thunder.tests.framework import ops, requiresCUDA +from thunder.tests.framework import instantiate, ops, requiresCUDA, NOTHING from thunder.tests.opinfos import opinfos, OpInfo, make_number, SampleInput from thunder.tests.make_tensor import make_tensor from thunder.torch import _torch_to_thunder_function_map, _inplace_to_out_of_place @@ -125,30 +125,6 @@ def test_functionalization(op: OpInfo, device: str, dtype: dtypes.dtype, executo ) -def test_invalid_cases(): - import thunder - - a = torch.randn((2, 2)) - - def f_with_reshape(a: torch.Tensor) -> torch.Tensor: - b = torch.reshape(a, (-1,)) - b.exp_() - return b - - jitted = thunder.jit(f_with_reshape) - with pytest.raises(NotImplementedError, match="in-place op to view tensors is not allowed but"): - jitted(a) - - def f_with_contiguous(a: torch.Tensor) -> torch.Tensor: - b = a.contiguous() - b.exp_() - return b - - jitted = thunder.jit(f_with_contiguous) - with pytest.raises(NotImplementedError, match="in-place op to `torch.Tensor.contiguous`"): - jitted(a) - - # TODO(crcrpar): Investigate the numerical accuracy when `train=True` and dtype is fp32. # with RTX6000 Ada and CUDA 12.3, I see somewhat huge error: # E AssertionError: Tensor-likes are not close! 
@@ -176,3 +152,87 @@ def test_parse_resnet18(train: bool): jitted = thunder.jit(model) x = make_tensor((1, 3, 224, 224), dtype=dtype, device=device) torch.testing.assert_close(jitted(x), ref_model(x)) + + +@instantiate( + dtypes=NOTHING, +) +def test_inplace_to_views(executor, device, _): + import thunder + + def f(a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + c = torch.exp(a) + d = torch.tanh(b) + + e = c.view(-1) + e += d.flatten() + + d.div_(a) + return c, d, e + + a, b = (make_tensor((2, 2), device=device, dtype=torch.float32) for _ in range(2)) + a_, b_ = a.clone().detach(), b.clone().detach() + + jittd_f = thunder.jit(f, executors=executor.executors_list()) + + c, d, e = jittd_f(a, b) + c_, d_, e_ = f(a_, b_) + + torch.testing.assert_close((c, d, e), (c_, d_, e_)) + + def g(a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + c = torch.exp(a) + d = torch.tanh(b) + + e, _ = c.chunk(2) + e *= 1.5 + + d.div_(a) + return d, e + + a, b = (make_tensor((2, 2), device=device, dtype=torch.float32) for _ in range(2)) + a_, b_ = a.clone().detach(), b.clone().detach() + + jittd_g = thunder.jit(g, executors=executor.executors_list()) + + d, e = jittd_g(a, b) + d_, e_ = g(a_, b_) + + torch.testing.assert_close((d, e), (d_, e_)) + + +@instantiate( + dtypes=NOTHING, +) +def test_error_of_inplace_to_views(executor, device, _): + import thunder + + def f(a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + c = torch.exp(a) + d = torch.tanh(b) + + e = c.flatten() + e += d.flatten() + + d.div_(a) + return c, d, e + + a, b = (make_tensor((2, 2), device=device, dtype=torch.float32) for _ in range(2)) + jittd_f = thunder.jit(f, executors=executor.executors_list()) + + with pytest.raises(NotImplementedError, match="in-place op of `torch.Tensor.add_` to `torch.flatten` output"): + _ = jittd_f(a, b) + + def f(a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + c = torch.exp(a) + d = torch.tanh(b) + + e, _ = c.chunk(2) + e *= 1.5 + + d.div_(a) + return c, d, e + + jittd_f = thunder.jit(f, executors=executor.executors_list()) + with pytest.raises(NotImplementedError, match="in-place op of `torch.Tensor.mul_`"): + _ = jittd_f(a, b) diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py index 50466e4afd..4e1cce9fcf 100644 --- a/thunder/torch/__init__.py +++ b/thunder/torch/__init__.py @@ -5079,3 +5079,29 @@ def reduce_scatter( **_torch_to_thunder_function_map, **{fn: fn for fn in _torch_noinline_functions}, } + + +# ref: https://pytorch.org/docs/stable/tensor_view.html +_syms_returning_runtime_dependently_views: set[Symbol] = {reshape, contiguous, to, flatten} + +_syms_returning_views: set[Symbol] = { + diagonal, + expand, + expand_as, + movedim, + permute, + select, + squeeze, + transpose, + t, + real, + unflatten, + unfold, + unsqueeze, + view, + view_as, + unbind, + split, + tensor_split, + chunk, +} From f334d39b20f662a59df9115f63295d1c8248c5fe Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Mon, 24 Jun 2024 11:22:57 +0200 Subject: [PATCH 12/14] reenable cudnn sdpa (#639) --- thunder/tests/opinfos.py | 9 --------- thunder/tests/test_cudnn_executor.py | 4 ---- thunder/tests/test_grad.py | 9 --------- 3 files changed, 22 deletions(-) diff --git a/thunder/tests/opinfos.py b/thunder/tests/opinfos.py index bd5e5efeba..3174a949f3 100644 --- a/thunder/tests/opinfos.py +++ b/thunder/tests/opinfos.py @@ -7880,15 +7880,6 @@ def 
grad_scaled_dot_product_attention_sample_generator(op, device, dtype, requir # NOTE: NotImplementedError: Could not run 'aten::_scaled_dot_product_efficient_attention' with arguments from the 'CPU' backend. # NOTE: NotImplementedError: Could not run 'aten::_scaled_dot_product_efficient_attention_backward' with arguments from the 'CPU' backend devicetypes=(devices.DeviceType.CUDA,), - test_directives=( - DecorateInfo( - pytest.mark.skip(reason="https://github.com/Lightning-AI/lightning-thunder/issues/567"), - "test_core_vs_torch_consistency", - dtypes=(datatypes.bfloat16, datatypes.float16, datatypes.float32), - devicetypes=(devices.DeviceType.CUDA,), - active_if=version_between(torch.__version__, min_ver="2.4.0a0", max_ver="2.4.0a99"), - ), - ), ) nn_ops.append(grad_sdpa_opinfo) diff --git a/thunder/tests/test_cudnn_executor.py b/thunder/tests/test_cudnn_executor.py index 5c05afb5b7..9196fa0721 100644 --- a/thunder/tests/test_cudnn_executor.py +++ b/thunder/tests/test_cudnn_executor.py @@ -200,10 +200,6 @@ def test_cudnn_vs_torch_consistency(op, device, dtype, *_): LooseVersion(cudnn.backend_version_string()) < LooseVersion("8.9.5"), reason="cuDNN is required to be at least `8.9.5`", ) -@pytest.mark.skipif( - version_between(torch.__version__, min_ver="2.4.0a0", max_ver="2.4.0a99"), - reason="https://github.com/Lightning-AI/lightning-thunder/issues/567", -) @pytest.mark.parametrize("may_cat_grad_qkv", (True, False), ids=("may-cat-grad-qkv", "never-cat-grad-qkv")) @pytest.mark.parametrize("dtype", grad_sdpa_cudnn_opinfo.dtypes(), ids=tuple(map(str, grad_sdpa_cudnn_opinfo.dtypes()))) def test_vjp_correctness_cudnn_sdpa(dtype, may_cat_grad_qkv): diff --git a/thunder/tests/test_grad.py b/thunder/tests/test_grad.py index 8a99ef0c41..ab1948a692 100644 --- a/thunder/tests/test_grad.py +++ b/thunder/tests/test_grad.py @@ -541,21 +541,12 @@ def test_vjp_correctness_index_put_manual(op, device, dtype, executor, comp): # NOTE Scaled_Dot_Product_Efficient_Attention_Backward does not support fp64 dtypes # RuntimeError: Only fp32, half & bf16 supported at the moment -@pytest.mark.skipif( - not version_between(torch.__version__, min_ver="2.4.0a0", max_ver="2.4.0a99"), - reason="https://github.com/Lightning-AI/lightning-thunder/issues/567", -) @ops( (get_opinfo("grad_forward_scaled_dot_product_attention"),), supported_dtypes=(dtypes.float16, dtypes.bfloat16), supported_devicetypes=(devices.DeviceType.CUDA,), ) def test_vjp_correctness_sdpa_manual(op, device, dtype, executor, comp): - if version_between(torch.__version__, min_ver="2.4.0a0", max_ver="2.4.0a99"): - raise pytest.skip( - "https://github.com/Lightning-AI/lightning-thunder/issues/567", - ) - for sample in op.sample_inputs(device, dtype, requires_grad=True): from thunder.executors.sdpaex import sdpa_ex From 18899a85cbcaf1f28f24110b2f11e382d9fd4124 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Mon, 24 Jun 2024 02:27:44 -0700 Subject: [PATCH 13/14] Allowing static constraint in torch/__init__.py (#613) --- thunder/core/prims.py | 10 ++++++++++ thunder/core/proxies.py | 1 + thunder/executors/pythonex.py | 8 ++++++++ thunder/tests/test_jit_general.py | 15 +++++++++++++++ thunder/torch/__init__.py | 4 ++++ 5 files changed, 38 insertions(+) diff --git a/thunder/core/prims.py b/thunder/core/prims.py index c0c11ab6f3..d84536ff7f 100644 --- a/thunder/core/prims.py +++ b/thunder/core/prims.py @@ -261,6 +261,8 @@ class PrimIDs(Enum): # Memory access methods ITEM = auto() COPY_ = auto() + # + SINK = auto() class OpTags(Enum): @@ -3888,3 +3890,11 @@ def 
copy__meta( copy_ = make_prim(PrimIDs.COPY_, "copy_", meta=copy__meta, tags=(OpTags.DONT_DCE,)) + + +def sink_meta(*args, **kwargs): + return + + +# TODO do we want another tag to remove this after prologue is constructed? +sink = make_prim(PrimIDs.SINK, "sink", meta=sink_meta, tags=(OpTags.DONT_DCE,)) diff --git a/thunder/core/proxies.py b/thunder/core/proxies.py index 786163cbb5..c0ff0e90f1 100644 --- a/thunder/core/proxies.py +++ b/thunder/core/proxies.py @@ -612,6 +612,7 @@ def known_value(self) -> bool: def make_static_constrained(self): baseutils.check(self.constraint != CONSTRAINT.DYNAMIC, lambda: f"dynamic NumberProxy cannot be made static") + baseutils.check(self.value is not None, lambda: f"static NumberProxy needs to have value") self.constraint = CONSTRAINT.STATIC def make_constrainable(self): diff --git a/thunder/executors/pythonex.py b/thunder/executors/pythonex.py index 9235f20c1c..227a7a1ac5 100644 --- a/thunder/executors/pythonex.py +++ b/thunder/executors/pythonex.py @@ -365,5 +365,13 @@ def _elementwise_binary_checker(a: NumberLike | TensorProxy, b: NumberLike | Ten ex.register_implementation(prims.sub, sub, checker=_elementwise_binary_checker) ex.register_implementation(prims.div, div, checker=_elementwise_binary_checker) + +def _sink(*args, **kwargs): + return + + +sink = ex.register_operator("sink", like=prims.sink, fn=_sink) +ex.register_implementation(prims.sink, sink, checker=_always_executable) + # TODO: Restore truediv once we find it... # ex.register_implementation(prims.truediv, truediv, checker=_elementwise_binary_checker) diff --git a/thunder/tests/test_jit_general.py b/thunder/tests/test_jit_general.py index da7fe97c47..d6c526ebc8 100644 --- a/thunder/tests/test_jit_general.py +++ b/thunder/tests/test_jit_general.py @@ -955,3 +955,18 @@ def bar(t): jbar = thunder.jit(bar, cache="symbolic values") t = torch.randn(4, device="cpu") jbar(t) + + +def test_cache_symbolic_values_torch_device(): + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + def foo(dev, idx): + # NOTE dtype needs to be explicit, see issue: https://github.com/Lightning-AI/lightning-thunder/issues/621 + return torch.ones(1, device=torch.device(dev, idx), dtype=torch.float32) + + jfoo = thunder.jit(foo, cache="symbolic values") + expected = foo("cuda", 0) + actual = jfoo("cuda", 0) + + assert_close(expected, actual) diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py index 4e1cce9fcf..ffed534449 100644 --- a/thunder/torch/__init__.py +++ b/thunder/torch/__init__.py @@ -4906,6 +4906,10 @@ def torch_device(device_or_str: DeviceLike, /, index: int | None = None) -> devi not has_device_idx, lambda: f"device string must not include an index because index was passed explicitly: {device_or_str}", ) + if isinstance(index, NumberProxy): + index.make_static_constrained() + prims.sink(index) + index = index.value return devices.Device(device_or_str, index) From fa55b090d88d769b926f67ea71087c5660f730bc Mon Sep 17 00:00:00 2001 From: Kaeun Kim <51257208+k223kim@users.noreply.github.com> Date: Mon, 24 Jun 2024 18:46:55 +0900 Subject: [PATCH 14/14] recursive error fix (#626) --- thunder/core/codeutils.py | 1 - thunder/core/pytree.py | 30 +++++++++++++++++++++++++++++- 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/thunder/core/codeutils.py b/thunder/core/codeutils.py index 3eaca625a3..4c631756eb 100644 --- a/thunder/core/codeutils.py +++ b/thunder/core/codeutils.py @@ -123,7 +123,6 @@ def to_printable( if is_collection(x): flat, spec = tree_flatten(x) - 
printables = [] for f in flat: printables.append(to_printable(trace, f, import_ctx=import_ctx, object_ctx=object_ctx)) diff --git a/thunder/core/pytree.py b/thunder/core/pytree.py index 96a4a322cb..15b6011621 100644 --- a/thunder/core/pytree.py +++ b/thunder/core/pytree.py @@ -2,6 +2,9 @@ import optree import torch +import thunder.core.dtypes as dtypes +import thunder.core.devices as devices +from thunder.core.baseutils import ProxyInterface # We need torch.Size to be treated the same way as a list or tuple # In PyTorch this is registered here: @@ -13,7 +16,32 @@ namespace=optree.registry.__GLOBAL_NAMESPACE, ) -tree_flatten = partial(optree.tree_flatten, none_is_leaf=True) + +def tree_flatten(args): + if type(args) not in { + dict, + list, + str, + int, + bool, + tuple, + torch.dtype, + float, + dtypes.floating, + dtypes.bool_, + devices.Device, + torch.memory_format, + type(None), + slice, + complex, + type, + type(Ellipsis), + torch.Size, + } and not isinstance(args, (ProxyInterface)): + raise TypeError(f"tree_flatten of type {type(args)} is not supported.") + return optree.tree_flatten(args, none_is_leaf=True) + + tree_map = partial(optree.tree_map, none_is_leaf=True)
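
---

Editor's note on the final patch: it replaces the bare `optree.tree_flatten` alias with a wrapper that only accepts an explicit allow-list of top-level types and otherwise raises a `TypeError` naming the offending type. Below is a minimal, self-contained sketch of that guard pattern, not the Thunder implementation itself; `SUPPORTED_TYPES`, `guarded_tree_flatten`, and `Opaque` are illustrative names, and the allow-list is abbreviated relative to the one in the patch (which also admits Thunder dtypes, devices, `torch.Size`, `torch.dtype`, slices, `Ellipsis`, and `ProxyInterface` instances).

```python
from functools import partial

import optree

# Illustrative, abbreviated allow-list of top-level pytree container/leaf types.
SUPPORTED_TYPES = {dict, list, tuple, str, int, bool, float, complex, type(None)}


def guarded_tree_flatten(args):
    # Reject top-level inputs whose type is not on the allow-list so callers get
    # a descriptive TypeError instead of a confusing downstream failure.
    if type(args) not in SUPPORTED_TYPES:
        raise TypeError(f"tree_flatten of type {type(args)} is not supported.")
    return optree.tree_flatten(args, none_is_leaf=True)


tree_map = partial(optree.tree_map, none_is_leaf=True)

# Supported containers flatten as before (None is kept as a leaf) ...
leaves, spec = guarded_tree_flatten({"a": [1, 2], "b": (3, None)})
assert leaves == [1, 2, 3, None]


# ... while an unsupported object is rejected immediately.
class Opaque:
    pass


try:
    guarded_tree_flatten(Opaque())
except TypeError as err:
    print(err)  # tree_flatten of type <class '__main__.Opaque'> is not supported.
```

The exact-type membership test (`type(args) not in ...`) deliberately rejects subclasses as well; this mirrors the wrapper in the patch, which only relaxes the check via `isinstance` for `ProxyInterface` objects.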