From b873afa84a9b1829f32ca8aa118191edc2da228a Mon Sep 17 00:00:00 2001
From: Kshiteej K
Date: Mon, 25 Mar 2024 16:44:38 +0100
Subject: [PATCH 1/6] transformer_engine: wrap checker_fn in langctx and
 cleanup (PR2473) (#24)

Co-authored-by: Thomas Viehmann
---
 thunder/executors/transformer_engineex.py | 14 +++++---------
 1 file changed, 5 insertions(+), 9 deletions(-)

diff --git a/thunder/executors/transformer_engineex.py b/thunder/executors/transformer_engineex.py
index ced5d8fdb1..ba53c3e940 100644
--- a/thunder/executors/transformer_engineex.py
+++ b/thunder/executors/transformer_engineex.py
@@ -20,6 +20,7 @@
 from thunder.core.proxies import TensorProxy, CollectionProxy
 from thunder.core.symbol import Symbol
 from thunder.extend import OperatorExecutor, register_executor
+from thunder.core.langctxs import langctx, Languages
 
 __all__ = [
     "transformer_engine_ex",
@@ -369,6 +370,10 @@ def bind_postprocess(bsym: BoundSymbol) -> None:
 #
 # Registers transformer_engine_ex as an executor for torch.nn.functional.linear
 #
+
+
+# NOTE: We need langctx so that we can resolve `view` on TensorProxy.
+@langctx(Languages.TORCH)
 def _linear_checker(
     a: TensorProxy,
     w: TensorProxy,
@@ -398,15 +403,6 @@ def linear_forwad_rule(a, w, bias):
     return primal, saved_for_backward
 
 
-def linear_forward_rule_checker(a: TensorProxy, w: TensorProxy, bias: None | TensorProxy) -> bool:
-    from thunder.core.compile_data import get_compile_data
-
-    cd = get_compile_data()
-    if transformer_engine_ex in cd.executors_list:
-        return _linear_checker(a, w, bias)
-    return False
-
-
 def linear_backward_rule(a_shape, w_shape, b_shape, ctx_idx, grad):
     return te_functional_linear_backward(grad, a_shape, w_shape, b_shape, ctx_idx)
 

From 0f03314b325a3d8480893813567a2cdbb3a80819 Mon Sep 17 00:00:00 2001
From: nikitaved
Date: Mon, 25 Mar 2024 17:05:16 +0100
Subject: [PATCH 2/6] `collections.namedtuple`: add lookaside (#47)

---
 thunder/core/interpreter.py       | 94 ++++++++++++++++++++++++++++++-
 thunder/tests/test_interpreter.py | 50 ++++++++++++++++
 2 files changed, 141 insertions(+), 3 deletions(-)

diff --git a/thunder/core/interpreter.py b/thunder/core/interpreter.py
index aea2baf52e..7b8c146531 100644
--- a/thunder/core/interpreter.py
+++ b/thunder/core/interpreter.py
@@ -404,6 +404,10 @@ def __init__(
 
         self._uncacheable_classes = uncacheable_classes
 
+    @property
+    def with_provenance_tracking(self):
+        return self._with_provenance_tracking
+
     def interpret(self, inst: dis.Instruction, /, **interpreter_state) -> None | int | INTERPRETER_SIGNALS:
         return self._opcode_interpreter(inst, **interpreter_state)
 
@@ -887,6 +891,7 @@ class PseudoInst(str, enum.Enum):
     BINARY_SUBSCR = "BINARY_SUBSCR"
     BUILD_DICT = "BUILD_DICT"
     BUILD_TUPLE = "BUILD_TUPLE"
+    BUILD_NAMEDTUPLE = "BUILD_NAMEDTUPLE"
     CONSTANT = "CONSTANT"
     EXCEPTION_HANDLER = "EXCEPTION_HANDLER"
     INPUT_ARGS = "INPUT_ARGS"
@@ -2589,6 +2594,55 @@ def impl(self, other):
     return _interpret_call(impl, self, other)
 
 
+def _collections_namedtuple_lookaside(
+    typename: str,
+    field_names: Iterable[str],
+    *,
+    rename: bool = False,
+    defaults: None | Iterable[Any] = None,
+    module: None | str = None,
+):
+    # Type checks {
+    assert wrapped_isinstance(typename, str)
+    assert wrapped_isinstance(field_names, Iterable)
+    assert wrapped_isinstance(rename, bool)
+    if defaults is not None:
+        assert wrapped_isinstance(defaults, Iterable)
+    if module is not None:
+        assert wrapped_isinstance(module, str)
+    # }
+
+    # Wrap defaults {
+    if not isinstance(rename, WrappedValue):
+        rename = wrap_const(rename)
+
+    if defaults is None:
+        defaults = wrap_const(defaults)
+
+    if module is None:
+        # To prevent taking module from the direct caller,
+        # we use the module's name from the active frame
+        curr_frame = get_interpreterruntimectx().frame_stack[-1]
+        module = unwrap(curr_frame.globals).get("__name__", None)
+        module = wrap_const(module)
+    # }
+
+    # Run opaque namedtuple {
+    @interpreter_needs_wrap
+    def create_namedtuple(typename: str, field_names: str, **kwargs):
+        namedtuple_type = collections.namedtuple(typename, field_names, **kwargs)
+        return namedtuple_type
+
+    namedtuple_type = create_namedtuple(typename, field_names, rename=rename, defaults=defaults, module=module)
+    if namedtuple_type is INTERPRETER_SIGNALS.EXCEPTION_RAISED:
+        return namedtuple_type
+
+    assert wrapped_isinstance(namedtuple_type, type)
+    # }
+
+    return namedtuple_type
+
+
 _default_lookaside_map: dict[Callable, Callable] = {
     # Jit lookasides
     is_jitting: _is_jitting_lookaside,
@@ -2612,6 +2666,7 @@ def impl(self, other):
     isinstance: _isinstance_lookaside,
     functools.reduce: _functools_reduce_lookaside,
     operator.getitem: _getitem_lookaside,
+    collections.namedtuple: _collections_namedtuple_lookaside,
 }
 
 
@@ -2619,9 +2674,11 @@ def impl(self, other):
 # immutuable sequences (tuples) are created with contents in __new__ and __init__ is a nop
 # (object.__init__, actually).
 def _tuple_new_provenance_tracking_lookaside(cls, iterable=(), /):
+    new_tuple_type = cls.value
+    assert issubclass(new_tuple_type, tuple)
+
     if iterable == ():
         iterable = wrap_const(())
-    assert cls.value is tuple
 
     if isinstance(iterable.value, (list, tuple)):
         # special case to avoid infinite recursion
@@ -2648,8 +2705,39 @@ def _tuple_new_provenance_tracking_lookaside(cls, iterable=(), /):
         else:
             item_wrappers.append(wv)
 
-    ures = tuple(w.value for w in item_wrappers)
-    pr = ProvenanceRecord(PseudoInst.BUILD_TUPLE, inputs=[w.provenance for w in item_wrappers])
+    def is_likely_from_collections_namedtuple(tuple_type):
+        from collections import namedtuple
+
+        # Check if tuple_type code object is coming from namedtuple
+        return (
+            hasattr(tuple_type, "__repr__")
+            and hasattr(tuple_type.__repr__, "__code__")
+            and tuple_type.__repr__.__code__ in namedtuple.__code__.co_consts
+        )
+
+    # Construction of namedtuples may raise
+    try:
+        ures = tuple(w.value for w in item_wrappers)
+        # Named tuples expect varargs, not iterables at new/init
+        if is_likely_from_collections_namedtuple(new_tuple_type):
+            if hasattr(new_tuple_type, "__bases__") and new_tuple_type.__bases__ == (tuple,):
+                ures = new_tuple_type(*ures)
+                build_inst = PseudoInst.BUILD_NAMEDTUPLE
+            else:
+                return do_raise(
+                    NotImplementedError(
+                        f"The type {new_tuple_type} is likely a subclassed named tuple. "
+                        "Subclassing the types returned by `collections.namedtuple` "
+                        "is currently not supported! Please, file an issue requesting this support."
+                    )
+                )
+        else:
+            ures = new_tuple_type(ures)
+            build_inst = PseudoInst.BUILD_TUPLE
+    except Exception as e:
+        return do_raise(e)
+
+    pr = ProvenanceRecord(build_inst, inputs=[w.provenance for w in item_wrappers])
     res = wrap(ures, provenance=pr)
     res.item_wrappers = item_wrappers
 
diff --git a/thunder/tests/test_interpreter.py b/thunder/tests/test_interpreter.py
index ab1b45289c..0170997d24 100644
--- a/thunder/tests/test_interpreter.py
+++ b/thunder/tests/test_interpreter.py
@@ -1030,6 +1030,56 @@ def add(x, y):
     assert jfoo((1, 2, 3), jadd) == 6
 
 
+def test_namedtuple_lookaside(jit):
+    from collections import namedtuple
+
+    typename = "MyNamedTuple"
+    field_names = ("a", "b", "c")
+
+    # Test returning just the type {
+    def f():
+        return namedtuple(typename, field_names)
+
+    jf = jit(f)
+
+    jtype = jf()
+    assert isinstance(jtype, type)
+    assert jtype.__name__ == typename
+    assert all(hasattr(jtype, field) for field in field_names)
+
+    # Check module name
+    import inspect
+
+    assert jtype.__module__ == inspect.currentframe().f_globals["__name__"]
+    # }
+
+    # Test accessing elements {
+    a = torch.rand(1)
+    b = torch.rand(1)
+    c = torch.rand(1)
+
+    def f(a, b, c):
+        nt = namedtuple(typename, field_names)
+        obj = nt(a, b, c)
+        return obj[0]
+
+    jf = jit(f)
+
+    assert f(a, b, c) is a
+    assert jf(a, b, c) is a
+
+    def f(a, b, c):
+        nt = namedtuple(typename, field_names)
+        obj = nt(a, b, c)
+        return obj.a
+
+    jf = jit(f)
+
+    assert f(a, b, c) is a
+    assert jf(a, b, c) is a
+    # }
+
+
 def test_calling_methods(jit):
     jitting = False
 

From 37129cf607c2bd3c56a2e1bdd0ec074a757ed1a3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carlos=20Mochol=C3=AD?=
Date: Mon, 25 Mar 2024 17:14:56 +0100
Subject: [PATCH 3/6] Update cc-bot issue number (#73)

---
 .github/lightning-probot.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/lightning-probot.yml b/.github/lightning-probot.yml
index 9398721547..80cf9b671e 100644
--- a/.github/lightning-probot.yml
+++ b/.github/lightning-probot.yml
@@ -1 +1 @@
-tracking_issue: 1464
+tracking_issue: 72

From be16444471352eeb1bd11c468efecb2c3b91e5f1 Mon Sep 17 00:00:00 2001
From: Ivan Yashchuk
Date: Mon, 25 Mar 2024 18:30:06 +0200
Subject: [PATCH 4/6] Rename thunder-fwd-bwd->thunder for
 test_llama2_qkv_split_rope_7b_train (#75)

---
 thunder/benchmarks/targets.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/thunder/benchmarks/targets.py b/thunder/benchmarks/targets.py
index 9b38e26c3d..7bbe14298a 100644
--- a/thunder/benchmarks/targets.py
+++ b/thunder/benchmarks/targets.py
@@ -873,8 +873,8 @@ def test_llama2_7b_rmsnorm_grad(benchmark, executor: Callable):
     ids=(
         "torch",
         "torch.compile",
-        "thunder-fwd-bwd",
-        "thunder+nvfuser+torch.compile-fwd-bwd",
+        "thunder",
+        "thunder+nvfuser+torch.compile",
         "torch+apex",
         "torch.compile+apex",
     ),

From f2e48b39eda18713439891da9fcbf0338e86e4d9 Mon Sep 17 00:00:00 2001
From: Luca Antiga
Date: Mon, 25 Mar 2024 12:33:30 -0400
Subject: [PATCH 5/6] Add custom CUDA kernels to README (#71)

---
 README.md      | 3 ++-
 pyproject.toml | 3 ++-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/README.md b/README.md
index ce5399d570..a672abe48e 100644
--- a/README.md
+++ b/README.md
@@ -154,7 +154,8 @@ Thunder doesn't generate code for accelerators directly. It acquires and transfo
 - [Apex](https://github.com/NVIDIA/apex)
 - [TransformerEngine](https://github.com/NVIDIA/TransformerEngine)
 - [PyTorch eager](https://github.com/pytorch/pytorch)
-- custom kernels, including those written with [OpenAI Triton](https://github.com/openai/triton)
+- Custom CUDA kernels through [PyCUDA](https://documen.tician.de/pycuda/tutorial.html#interoperability-with-other-libraries-using-the-cuda-array-interface), [Numba](https://numba.readthedocs.io/en/stable/cuda/kernels.html), [CuPy](https://docs.cupy.dev/en/stable/user_guide/kernel.html)
+- Custom kernels written in [OpenAI Triton](https://github.com/openai/triton)
 
 Modules and functions compiled with Thunder fully interoperate with vanilla PyTorch and support PyTorch's autograd. Also, Thunder works alongside torch.compile to leverage its state-of-the-art optimizations.
 
diff --git a/pyproject.toml b/pyproject.toml
index 520a661279..fc03746330 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -56,7 +56,8 @@ quiet-level = 3
 # https://github.com/codespell-project/codespell/issues/2839#issuecomment-1731601603
 # also adding links until they ignored by its: nature
 # https://github.com/codespell-project/codespell/issues/2243#issuecomment-1732019960
-#ignore-words-list = ""
+# documen is used in a URL in the README
+ignore-words-list = "documen"
 
 
 [tool.black]

From 185743614c05e27a5e4531cfbf41d40cf4cd5574 Mon Sep 17 00:00:00 2001
From: Vedaanta Agarwalla <142048820+vedaanta-nvidia@users.noreply.github.com>
Date: Tue, 26 Mar 2024 02:31:32 -0700
Subject: [PATCH 6/6] bumps cudnnex to 1.2.1 (PR2485) (#36)

Co-authored-by: Vedaanta Agarwalla
Co-authored-by: Jirka
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
---
 .azure/docker-build.yml                   | 8 ++------
 .azure/gpu-tests.yml                      | 8 ++++----
 dockers/ubuntu-cuda/Dockerfile            | 2 +-
 docs/source/fundamentals/installation.rst | 2 +-
 4 files changed, 8 insertions(+), 12 deletions(-)

diff --git a/.azure/docker-build.yml b/.azure/docker-build.yml
index 73233ae78c..bbef960124 100644
--- a/.azure/docker-build.yml
+++ b/.azure/docker-build.yml
@@ -40,14 +40,10 @@ jobs:
       #maxParallel: "3"
       matrix:
         # CUDA 12.1
-        'cuda 12.1 | torch 2.2 | cudnn FE v1.1':  # todo: drop updating this image when CI transition to newer FE version
-          {CUDA_VERSION: '12.1.1', TORCH_VERSION: '2.2.1', TRITON_VERSION: '2.2.0', CUDNN_FRONTEND: "1.1.0"}
         'cuda 12.1 | torch 2.2 | cudnn FE v1.2':
-          {CUDA_VERSION: '12.1.1', TORCH_VERSION: '2.2.1', TRITON_VERSION: '2.2.0', CUDNN_FRONTEND: "1.2.0"}
-        'cuda 12.1 | torch 2.3 /nightly | cudnn FE v1.1':  # todo: drop updating this image when CI transition to newer FE version
-          {CUDA_VERSION: '12.1.1', TORCH_VERSION: 'main', TORCH_INSTALL: 'source', CUDNN_FRONTEND: "1.1.0"}
+          {CUDA_VERSION: '12.1.1', TORCH_VERSION: '2.2.1', TRITON_VERSION: '2.2.0', CUDNN_FRONTEND: "1.2.1"}
         'cuda 12.1 | torch 2.3 /nightly | cudnn FE v1.2':
-          {CUDA_VERSION: '12.1.1', TORCH_VERSION: 'main', TORCH_INSTALL: 'source', CUDNN_FRONTEND: "1.2.0"}
+          {CUDA_VERSION: '12.1.1', TORCH_VERSION: 'main', TORCH_INSTALL: 'source', CUDNN_FRONTEND: "1.2.1"}
         #'cuda 12.1': # this version - '8.9.5.29-1+cuda12.1' for 'libcudnn8' was not found
   # how much time to give 'run always even if cancelled tasks' before stopping them
   cancelTimeoutInMinutes: "2"

diff --git a/.azure/gpu-tests.yml b/.azure/gpu-tests.yml
index 049a372942..d2c251ef98 100644
--- a/.azure/gpu-tests.yml
+++ b/.azure/gpu-tests.yml
@@ -17,17 +17,17 @@ jobs:
     matrix:
       # CUDA 12.1
       'ubuntu22.04 | cuda 12.1 | python 3.10 | torch 2.2 | regular':
-        docker-image: 'pytorchlightning/lightning-thunder:ubuntu22.04-cuda12.1.1-cudnn-fe1.1.0-py3.10-pt_2.2.1'
+        docker-image: 'pytorchlightning/lightning-thunder:ubuntu22.04-cuda12.1.1-cudnn-fe1.2.1-py3.10-pt_2.2.1'
         CUDA_VERSION_MM: '121'
       'ubuntu22.04 | cuda 12.1 | python 3.10 | torch 2.2 | distributed':
-        docker-image: 'pytorchlightning/lightning-thunder:ubuntu22.04-cuda12.1.1-cudnn-fe1.1.0-py3.10-pt_2.2.1'
+        docker-image: 'pytorchlightning/lightning-thunder:ubuntu22.04-cuda12.1.1-cudnn-fe1.2.1-py3.10-pt_2.2.1'
         CUDA_VERSION_MM: '121'
         testing: 'distributed'
       'ubuntu22.04 | cuda 12.1 | python 3.10 | torch-nightly | regular':
-        docker-image: 'pytorchlightning/lightning-thunder:ubuntu22.04-cuda12.1.1-cudnn-fe1.2.0-py3.10-pt_main'
+        docker-image: 'pytorchlightning/lightning-thunder:ubuntu22.04-cuda12.1.1-cudnn-fe1.2.1-py3.10-pt_main'
         CUDA_VERSION_MM: '121'
       'ubuntu22.04 | cuda 12.1 | python 3.10 | torch-nightly | distributed':
-        docker-image: 'pytorchlightning/lightning-thunder:ubuntu22.04-cuda12.1.1-cudnn-fe1.2.0-py3.10-pt_main'
+        docker-image: 'pytorchlightning/lightning-thunder:ubuntu22.04-cuda12.1.1-cudnn-fe1.2.1-py3.10-pt_main'
         CUDA_VERSION_MM: '121'
         testing: 'distributed'
   # how long to run the job before automatically cancelling

diff --git a/dockers/ubuntu-cuda/Dockerfile b/dockers/ubuntu-cuda/Dockerfile
index 2213921897..bf28411ece 100644
--- a/dockers/ubuntu-cuda/Dockerfile
+++ b/dockers/ubuntu-cuda/Dockerfile
@@ -20,7 +20,7 @@ ARG IMAGE_TYPE="devel"
 FROM nvidia/cuda:${CUDA_VERSION}-${IMAGE_TYPE}-ubuntu${UBUNTU_VERSION}
 
 ARG CUDNN_VERSION="8.9.7.29-1"
-ARG CUDNN_FRONTEND_CHECKOUT="v1.1.0"
+ARG CUDNN_FRONTEND_CHECKOUT="v1.2.1"
 ARG PYTHON_VERSION="3.10"
 ARG TORCH_VERSION="2.2.1"
 ARG TRITON_VERSION="2.2.0"

diff --git a/docs/source/fundamentals/installation.rst b/docs/source/fundamentals/installation.rst
index 8d41a24047..a426f947b6 100644
--- a/docs/source/fundamentals/installation.rst
+++ b/docs/source/fundamentals/installation.rst
@@ -39,7 +39,7 @@ Thunder can use NVIDIA's cuDNN Python frontend bindings to accelerate some PyTor
 
    export CUDNN_PATH=/usr/local/lib/python3.10/dist-packages/nvidia/cudnn/
    for file in $CUDNN_PATH/lib/*.so.[0-9]; do filename_without_version="${file%??}"; ln -s $file $filename_without_version; done
 
-   git clone -b v1.1.0 https://github.com/NVIDIA/cudnn-frontend.git
+   git clone -b v1.2.1 https://github.com/NVIDIA/cudnn-frontend.git
    export CUDAToolkit_ROOT=/path/to/cuda
    CMAKE_BUILD_PARALLEL_LEVEL=16 pip install cudnn_frontend/ -v
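
A minimal usage sketch of the `collections.namedtuple` lookaside added in PATCH 2/6, assuming `thunder.jit` as the public entry point (the tests above reach the interpreter through the suite's `jit` fixture); the snippet is illustrative only and not part of the patches:

    # Sketch, assuming thunder.jit drives the interpreter with default lookasides.
    # Namedtuples constructed inside a jitted function are handled by the new
    # lookaside, so both attribute access (p.x) and index access (p[1]) work.
    from collections import namedtuple

    import torch
    import thunder

    def f(a, b):
        Point = namedtuple("Point", ("x", "y"))
        p = Point(a, b)
        return p.x + p[1]

    jf = thunder.jit(f)
    a, b = torch.randn(2), torch.randn(2)
    assert torch.allclose(jf(a, b), f(a, b))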