diff --git a/docs/source/index.rst b/docs/source/index.rst index 9507b84129..c7947a2b27 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -102,6 +102,7 @@ The compiled function ``jitted_foo`` takes and returns PyTorch tensors, just lik FSDP Under the Hood Tutorial Benchmarking Thunder Writing a Transform + Using Liger-Kernel with Thunder .. toctree:: :maxdepth: 1 @@ -143,6 +144,7 @@ API reference reference/torch/index reference/extend/index reference/transforms/index + reference/dynamo/index Indices and tables diff --git a/docs/source/reference/dynamo/index.rst b/docs/source/reference/dynamo/index.rst new file mode 100644 index 0000000000..d05a21e284 --- /dev/null +++ b/docs/source/reference/dynamo/index.rst @@ -0,0 +1,9 @@ +.. module:: thunder.dynamo + +thunder.dynamo +============== + +.. autosummary:: + :toctree: + + ThunderCompiler diff --git a/notebooks/liger_kernel.ipynb b/notebooks/liger_kernel.ipynb new file mode 100644 index 0000000000..20265b116f --- /dev/null +++ b/notebooks/liger_kernel.ipynb @@ -0,0 +1,1010 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "2055f09e-ea78-4726-9a48-65115000b140", + "metadata": {}, + "source": [ + "# Thunder bindings for Liger operators\n", + "\n", + "In this notebook we explore Thunder bindings for Liger operators.\n", + "\n", + "It is based on [Episode 10 of the Thunder Sessions podcast](https://www.youtube.com/watch?v=3H_aw6o-d9c&list=PLaMu-SDt_RB7ImARcTT_Wjypwx2vBIBen&index=10).\n", + "\n", + "Let's import things." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "8f4102a8-f68b-4012-bd5b-a1d3daeab367", + "metadata": {}, + "outputs": [], + "source": [ + "from collections.abc import Sequence\n", + "import math\n", + "\n", + "import torch\n", + "from torch.testing import assert_close\n", + "import litgpt\n", + "import thunder\n", + "from thunder.core.proxies import TensorProxy, AnyProxy\n", + "from thunder.core.transforms import get_grad, put_grads\n", + "from thunder.torch import TensorLike\n", + "import thunder.extend\n", + "\n", + "import liger_kernel.ops.rms_norm\n", + "import liger_kernel.ops.rope\n", + "import liger_kernel.ops.swiglu\n", + "import liger_kernel.ops.geglu # TODO\n", + "import liger_kernel.ops.cross_entropy # TODO\n", + "import liger_kernel.ops.fused_linear_cross_entropy\n", + "\n", + "device = torch.device(\"cuda\")" + ] + }, + { + "cell_type": "markdown", + "id": "6b44643e-a92c-4398-861f-793cec2e7414", + "metadata": {}, + "source": [ + "We define and register an executor." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "c5232472-a67c-4650-abf9-370e4692e93d", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "thunder.extend.OperatorExecutor('liger')" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "liger_ex = thunder.extend.OperatorExecutor(\"liger\", version=\"0.1\")\n", + "thunder.extend.register_executor(liger_ex)" + ] + }, + { + "cell_type": "markdown", + "id": "b207657e-a40c-4cda-a2d6-3f0e11ae4949", + "metadata": {}, + "source": [ + "## RMS Norm\n", + "\n", + "The first thing to fuse is RMS Norm.\n", + "\n", + "Once LitGPT is routed through `torch.nn.functional.rms_norm` (see the next cell), Liger's implementation is a drop-in replacement. We define operators for the Liger forward and backward kernels and then a gradient rule and an execution rule.\n", + "\n", + "We register these as an implementation for the `rms_norm` operator to which we divert the PyTorch function.\n",
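+ "\n",
+ "As a rough sketch (not part of the original session; the `my_scale_*` names are made-up placeholders), the registration pattern used throughout this notebook is: register forward and backward operators with meta functions, write a gradient rule with `get_grad`/`put_grads`, and attach everything with `register_implementation`.\n",
+ "\n",
+ "```python\n",
+ "def my_scale_fwd_meta(x):\n",
+ "    return TensorProxy(like=x)  # meta functions only describe outputs, nothing is computed\n",
+ "\n",
+ "\n",
+ "def my_scale_fwd_impl(x):\n",
+ "    return x * 2\n",
+ "\n",
+ "\n",
+ "my_scale_fwd = liger_ex.register_operator(\"my_scale_fwd\", meta=my_scale_fwd_meta, fn=my_scale_fwd_impl)\n",
+ "\n",
+ "\n",
+ "def my_scale_bwd_meta(grad_y):\n",
+ "    return TensorProxy(like=grad_y)\n",
+ "\n",
+ "\n",
+ "def my_scale_bwd_impl(grad_y):\n",
+ "    return grad_y * 2  # d(2 * x)/dx = 2\n",
+ "\n",
+ "\n",
+ "my_scale_bwd = liger_ex.register_operator(\"my_scale_bwd\", meta=my_scale_bwd_meta, fn=my_scale_bwd_impl)\n",
+ "\n",
+ "\n",
+ "def my_scale_grad_transform(x):\n",
+ "    y = my_scale_fwd(x)\n",
+ "    put_grads((x,), (my_scale_bwd(get_grad(y)),))\n",
+ "    return y\n",
+ "\n",
+ "\n",
+ "liger_ex.register_implementation(\n",
+ "    my_scale_fwd, execution_transform=my_scale_fwd, grad_transform=my_scale_grad_transform\n",
+ ")\n",
+ "```\n",
+ "\n",
+ "The RMS Norm binding below follows the same shape, with Liger's real forward and backward kernels in place of the toy functions."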
+ ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "4411cc5f-5535-48e2-ba7c-00b984f15ad2", + "metadata": {}, + "outputs": [], + "source": [ + "# A tiny detail here is that PyTorch gained a `rms_norm` function somewhat\n", + "# recently and we need to tell LitGPT to use it.\n", + "\n", + "\n", + "def RMSNorm_forward(self, x):\n", + " return torch.nn.functional.rms_norm(x, self.weight.shape, self.weight, self.eps)\n", + "\n", + "\n", + "litgpt.model.RMSNorm.forward = RMSNorm_forward" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "77757535-b292-4a96-a6a3-c0e7f05d70ea", + "metadata": {}, + "outputs": [], + "source": [ + "import functools\n", + "\n", + "prod = lambda *args: functools.reduce(lambda x, y: x * y, args)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "f542954c-aba1-4523-9a7d-436348a6af96", + "metadata": {}, + "outputs": [], + "source": [ + "# ******************************* RMS NORM *******************************\n", + "import functools\n", + "\n", + "\n", + "def liger_rms_norm_forward_meta(X, W, eps, offset, casting_mode):\n", + " *n_rows, n_cols = X.shape\n", + " n_rows = prod(*n_rows)\n", + " # RSTD is always computed/stored in fp32 if we are using Llama or Gemma casting mode\n", + " rstd_dtype = (\n", + " thunder.dtypes.float32\n", + " if casting_mode\n", + " in (liger_kernel.ops.rms_norm._CASTING_MODE_LLAMA.value, liger_kernel.ops.rms_norm._CASTING_MODE_GEMMA.value)\n", + " else X.dtype\n", + " )\n", + " Y = TensorProxy(like=X)\n", + " RSTD = TensorProxy(like=X, shape=(n_rows,), dtype=rstd_dtype)\n", + " BLOCK_SIZE, num_warps = liger_kernel.ops.rms_norm.calculate_settings(n_cols)\n", + " return Y, TensorProxy(like=X, shape=(n_rows, n_cols)), RSTD, BLOCK_SIZE, num_warps, casting_mode\n", + "\n", + "\n", + "liger_rms_norm_forward = liger_ex.register_operator(\n", + " \"liger_rms_norm_forward\", meta=liger_rms_norm_forward_meta, fn=liger_kernel.ops.rms_norm.rms_norm_forward\n", + ")\n", + "\n", + "\n", + "def liger_rms_norm_backward_meta(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps):\n", + " return TensorProxy(like=X), TensorProxy(like=W)\n", + "\n", + "\n", + "liger_rms_norm_backward = liger_ex.register_operator(\n", + " \"liger_rms_norm_backward\", meta=liger_rms_norm_backward_meta, fn=liger_kernel.ops.rms_norm.rms_norm_backward\n", + ")\n", + "\n", + "\n", + "def rms_norm_meta(x, shape, w, eps):\n", + " return thunder.TensorProxy(like=x)\n", + "\n", + "\n", + "rms_norm = liger_ex.register_operator(\n", + " \"rms_norm\", meta=rms_norm_meta, fn=torch.nn.functional.rms_norm, replaces=torch.nn.functional.rms_norm\n", + ")\n", + "\n", + "\n", + "def rms_norm_grad_transform(x, shape, weight, eps):\n", + " Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = liger_rms_norm_forward(\n", + " x, weight, eps, offset=0.0, casting_mode=\"llama\"\n", + " )\n", + " dY = get_grad(Y)\n", + " dX, dW = liger_rms_norm_backward(\n", + " dY, X, weight, RSTD, offset=0.0, casting_mode=\"llama\", BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps\n", + " )\n", + " dX = dX.view(*x.shape)\n", + " put_grads((x, weight), (dX, dW))\n", + " return Y\n", + "\n", + "\n", + "def rms_norm_execution_transform(x, weight, eps):\n", + " Y, *_ = liger_rms_norm_forward(x, weight, eps, offset=0.0, casting_mode=\"llama\")\n", + " return Y\n", + "\n", + "\n", + "liger_ex.register_implementation(\n", + " rms_norm, execution_transform=rms_norm_execution_transform, grad_transform=rms_norm_grad_transform\n", + ")" + ] + }, + { + "cell_type": "markdown", + 
"id": "0ace1ad2-25f4-4a20-ad39-1f030bca0f38", + "metadata": {}, + "source": [ + "### Testing RMS Norm\n", + "\n", + "Let's test." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "56f1d6ee-a4ac-42f1-9d65-2c774cda4d18", + "metadata": {}, + "outputs": [], + "source": [ + "hidden_size = 64\n", + "\n", + "example_input = torch.randn(32, 10, hidden_size, device=device, requires_grad=True)\n", + "\n", + "with device:\n", + " model = litgpt.model.RMSNorm(hidden_size)\n", + "thunder_model = thunder.jit(model, executors=[liger_ex])\n", + "ref = model(example_input.clone())\n", + "res = thunder_model(example_input.clone())\n", + "go = torch.randn_like(ref)\n", + "grad_ref, grad_ref_weight = torch.autograd.grad(ref, (example_input, model.weight), go)\n", + "grad_res, grad_res_weight = torch.autograd.grad(res, (example_input, model.weight), go)\n", + "\n", + "\n", + "assert liger_rms_norm_forward in {bsym.sym for bsym in thunder.last_traces(thunder_model)[-1].bound_symbols}\n", + "assert liger_rms_norm_backward in {bsym.sym for bsym in thunder.last_backward_traces(thunder_model)[-1].bound_symbols}\n", + "\n", + "assert_close(ref, res)\n", + "assert_close(grad_ref, grad_res)\n", + "assert_close(grad_ref_weight, grad_res_weight)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "60dcb262-2255-4c17-b64f-f38e8ebd8e33", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "71c49c38-ce84-4727-9f57-8b42b908349f", + "metadata": {}, + "source": [ + "# RoPE\n", + "\n", + "Next is the RoPE implementation. Liger does both rope applications to query and key in one kernel whereas\n", + "LitGPT uses two. So we define not only forward and backward and a symbol to capture the litgpt version,\n", + "but also a small transform fusing the two `apply_rope` calls to one `liger_rope`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "32cd98f0-a36f-4e01-8ae4-6e36adf2699b", + "metadata": {}, + "outputs": [], + "source": [ + "def liger_rope_forward_meta(q, k, cos, sin):\n", + " return TensorProxy(like=q), TensorProxy(like=k), cos, sin\n", + "\n", + "\n", + "liger_rope_forward = liger_ex.register_operator(\n", + " \"liger_rope_forward\",\n", + " meta=liger_rope_forward_meta,\n", + " fn=liger_kernel.ops.rope.rope_forward,\n", + ")\n", + "\n", + "\n", + "def liger_rope_backward_meta(dq, dk, cos, sin):\n", + " return TensorLike(like=dq), TensorLike(like=dk)\n", + "\n", + "\n", + "liger_rope_backward = liger_ex.register_operator(\n", + " \"liger_rope_backward\",\n", + " meta=liger_rope_backward_meta,\n", + " fn=liger_kernel.ops.rope.rope_backward,\n", + ")\n", + "\n", + "\n", + "def liger_rope_grad_transform(q, k, cos, sin):\n", + " q_out, k_out, _, _ = liger_rope_forward(q, k, cos, sin)\n", + " q_out_grad = get_grad(q_out)\n", + " k_out_grad = get_grad(k_out)\n", + " dq, dk = liger_rope_backward(q_out_grad, k_out_grad, cos, sin)\n", + " put_grads((q, k), (dq, dk))\n", + " return q_out, k_out\n", + "\n", + "\n", + "def liger_rope_execution_transform(q, k, cos, sin):\n", + " q_out, k_out, _, _ = liger_rope_forward(q, k, cos, sin)\n", + " return q_out, k_out\n", + "\n", + "\n", + "def liger_rope_impl(q, k, cos, sin):\n", + " qr, kr, _, _ = liger_rope_forward(q, k, cos, sin)\n", + " return qr, kr\n", + "\n", + "\n", + "liger_rope = liger_ex.register_operator(\"liger_rope\", fn=liger_rope_impl, like=liger_rope_impl)\n", + "\n", + "liger_ex.register_implementation(\n", + " liger_rope,\n", + " execution_transform=liger_rope_execution_transform,\n", + " grad_transform=liger_rope_grad_transform,\n", + ")\n", + "\n", + "\n", + "def litgpt_apply_rope_meta(x, cos, sin):\n", + " return TensorProxy(like=x)\n", + "\n", + "\n", + "litgpt_apply_rope = liger_ex.register_operator(\n", + " \"litgpt_apply_rope\", fn=litgpt.model.apply_rope, meta=litgpt_apply_rope_meta, replaces=litgpt.model.apply_rope\n", + ")\n", + "\n", + "\n", + "class MergeRopeTransform(thunder.core.transform_common.Transform):\n", + " def transform_traces_pre_prologue(self, prologue_trace, compute_trace, epilogue_trace, **kwargs):\n", + " new_compute_trace = thunder.core.trace.from_trace(compute_trace)\n", + " bound_symbols = compute_trace.bound_symbols[:]\n", + " while bound_symbols:\n", + " bsym = bound_symbols.pop(0)\n", + " if bsym.sym == litgpt_apply_rope:\n", + " for i, bsym2 in enumerate(bound_symbols):\n", + " assert not any(o is bsym.output for o in bsym2.flat_outs)\n", + " if bsym2.sym == litgpt_apply_rope:\n", + " break\n", + " bsym2 = bound_symbols.pop(i)\n", + " assert bsym2.sym == litgpt_apply_rope\n", + "\n", + " output = (bsym.output, bsym2.output)\n", + " args = (bsym.args[0], bsym2.args[0], *bsym.args[1:])\n", + "\n", + " new_compute_trace.bound_symbols.append(bsym.from_bsym(args=args, output=output, sym=liger_rope))\n", + " else:\n", + " new_compute_trace.bound_symbols.append(bsym.from_bsym())\n", + " new_compute_trace.set_provenance(thunder.core.trace.TraceProvenance(self.__class__))\n", + " return prologue_trace, new_compute_trace, epilogue_trace" + ] + }, + { + "cell_type": "markdown", + "id": "44187b29-c101-41f0-a811-4c9f29757c81", + "metadata": {}, + "source": [ + "# Test\n", + "\n", + "We test with a scaled-down Llama." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "b8fd2563-7b89-487d-8fdc-21661380e2c0", + "metadata": {}, + "outputs": [], + "source": [ + "cfg = litgpt.Config.from_name(\"Llama-3.2-1B\", n_layer=1)\n", + "with device:\n", + " m = litgpt.GPT(cfg)\n", + " m.max_seq_length = 1024\n", + " m.set_kv_cache(1)\n", + " inp = torch.arange(1, 6, dtype=torch.int64)[None]\n", + " inp_pos = torch.arange(1, 6, dtype=torch.int64)\n", + "\n", + "\n", + "jm = thunder.jit(m, executors=(liger_ex,), transforms=(MergeRopeTransform(),))\n", + "res = jm(inp, inp_pos)\n", + "ref = m(inp, inp_pos)\n", + "\n", + "go = torch.randn_like(res)\n", + "(grad_res,) = torch.autograd.grad(res, jm.get_parameter(\"transformer.wte.weight\"), go)\n", + "(grad_ref,) = torch.autograd.grad(ref, m.get_parameter(\"transformer.wte.weight\"), go)\n", + "\n", + "assert_close(res, ref)\n", + "assert_close(grad_res, grad_ref)\n", + "\n", + "assert any(bsym.sym is liger_rope_forward for bsym in thunder.last_traces(jm)[-1].bound_symbols)\n", + "assert any(bsym.sym is liger_rope_backward for bsym in thunder.last_backward_traces(jm)[-1].bound_symbols)\n", + "assert any(bsym.sym is liger_rms_norm_forward for bsym in thunder.last_traces(jm)[-1].bound_symbols)\n", + "assert any(bsym.sym is liger_rms_norm_backward for bsym in thunder.last_backward_traces(jm)[-1].bound_symbols)" + ] + }, + { + "cell_type": "markdown", + "id": "e341460b-71d4-4e83-b67c-14bdec7d8026", + "metadata": {}, + "source": [ + "## SwiGLU\n" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "5e26e9ec-3eb0-46e8-90d1-8cd84d5dd1b7", + "metadata": {}, + "outputs": [], + "source": [ + "def liger_swiglu_forward_meta(a, b):\n", + " return TensorProxy(like=a)\n", + "\n", + "\n", + "def liger_swiglu_forward_impl(a, b):\n", + " _, _, res = liger_kernel.ops.swiglu.swiglu_forward(a, b)\n", + " return res\n", + "\n", + "\n", + "liger_swiglu_forward = liger_ex.register_operator(\n", + " \"liger_swiglu_forward\",\n", + " meta=liger_swiglu_forward_meta,\n", + " fn=liger_swiglu_forward_impl,\n", + ")\n", + "\n", + "\n", + "def liger_swiglu_backward_meta(a, b, grad_res):\n", + " return TensorProxy(like=a), TensorProxy(like=b)\n", + "\n", + "\n", + "liger_swiglu_backward = liger_ex.register_operator(\n", + " \"liger_swiglu_backward\",\n", + " meta=liger_swiglu_backward_meta,\n", + " fn=liger_kernel.ops.swiglu.swiglu_backward,\n", + ")\n", + "\n", + "\n", + "def liger_swiglu_gradient_transform(a, b):\n", + " res = liger_swiglu_forward(a, b)\n", + " grad_res = get_grad(res)\n", + " grad_a, grad_b = liger_swiglu_backward(a, b, grad_res)\n", + " put_grads((a, b), (grad_a, grad_b))\n", + " return res\n", + "\n", + "\n", + "liger_ex.register_implementation(\n", + " liger_swiglu_forward, grad_transform=liger_swiglu_gradient_transform, execution_transform=liger_swiglu_forward\n", + ")\n", + "\n", + "\n", + "class FuseSwigLUTransform(thunder.core.transform_common.Transform):\n", + " def transform_traces_pre_prologue(self, prologue_trace, computation_trace, epilogue_trace, **kwargs):\n", + " _, consumers = thunder.core.utils.producers_and_consumers(computation_trace)\n", + " new_computation_trace = thunder.core.trace.from_trace(computation_trace)\n", + " bsyms_to_skip = set()\n", + " for b in computation_trace.bound_symbols:\n", + " if b in bsyms_to_skip:\n", + " continue\n", + " new_bsym = b\n", + " if b.sym == thunder.torch.silu:\n", + " c = consumers[b.output]\n", + " if len(c) == 1 and c[0].sym == thunder.torch.mul:\n", + " (mul,) = c\n", + " mul_l, mul_r = 
mul.args\n", + " if mul_l is b.output:\n", + " other = mul_r\n", + " else:\n", + " other = mul_l\n", + " new_bsym = b.from_bsym(\n", + " sym=liger_swiglu_forward, output=mul.output, args=(b.args[0], other), subsymbols=[]\n", + " )\n", + " bsyms_to_skip.add(mul)\n", + " new_computation_trace.bound_symbols.append(new_bsym)\n", + " new_computation_trace.set_provenance(thunder.core.trace.TraceProvenance(\"constructed by FuseSwigLU\"))\n", + " return prologue_trace, new_computation_trace, epilogue_trace" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c004b1f6-9756-44ae-88f3-d088f0c838c6", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "efb5dcb5-a411-4832-9cac-a868dc3142b0", + "metadata": {}, + "source": [ + "## Fused Linear and Cross Entropy" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "55ff0b33-99ec-4de1-a3c0-7a78ebdf83c4", + "metadata": {}, + "outputs": [], + "source": [ + "def liger_fused_linear_cross_entropy_forward_meta(\n", + " _input, weight, target, bias=None, ignore_index=-100, label_smoothing=0.0, reduction=\"mean\"\n", + "):\n", + " logits = thunder.torch.linear(_input, weight, bias)\n", + " loss = thunder.torch.cross_entropy(\n", + " logits, target, ignore_index=ignore_index, label_smoothing=label_smoothing, reduction=reduction\n", + " )\n", + " grad_input = TensorProxy(like=_input)\n", + " grad_weight = TensorProxy(like=weight)\n", + " grad_bias = None if bias is None else TensorProxy(like=bias)\n", + " return loss, grad_input, grad_weight, grad_bias\n", + "\n", + "\n", + "liger_fused_linear_cross_entropy_forward = liger_ex.register_operator(\n", + " \"liger_fused_linear_cross_entropy_forward\",\n", + " fn=liger_kernel.ops.fused_linear_cross_entropy.fused_linear_cross_entropy_forward,\n", + " like=liger_fused_linear_cross_entropy_forward_meta,\n", + ")\n", + "\n", + "\n", + "def liger_fused_linear_cross_entropy_backward_meta(grad_output, grad_input, grad_weight, grad_bias):\n", + " return (\n", + " TensorProxy(like=grad_input),\n", + " TensorProxy(like=grad_weight),\n", + " (TensorProxy(like=grad_bias) if grad_bias is not None else None),\n", + " )\n", + "\n", + "\n", + "liger_fused_linear_cross_entropy_backward = liger_ex.register_operator(\n", + " \"liger_fused_linear_cross_entropy_backward\",\n", + " fn=liger_kernel.ops.fused_linear_cross_entropy.fused_linear_cross_entropy_backward,\n", + " meta=liger_fused_linear_cross_entropy_backward_meta,\n", + ")\n", + "\n", + "\n", + "def liger_fused_linear_cross_entropy_grad_transform(\n", + " _input, weight, target, bias=None, ignore_index=-100, label_smoothing=0.0, reduction=\"mean\"\n", + "):\n", + " loss, grad_input_1, grad_weight_1, grad_bias_1 = liger_fused_linear_cross_entropy_forward(\n", + " _input,\n", + " weight,\n", + " target,\n", + " bias=bias,\n", + " ignore_index=ignore_index,\n", + " label_smoothing=label_smoothing,\n", + " reduction=reduction,\n", + " )\n", + " grad_loss = get_grad(loss)\n", + " grad_input, grad_weight, grad_bias = liger_fused_linear_cross_entropy_backward(\n", + " grad_loss, grad_input_1, grad_weight_1, grad_bias_1\n", + " )\n", + " put_grads((_input, weight, target), (grad_input, grad_weight, grad_bias))\n", + " return loss\n", + "\n", + "\n", + "liger_ex.register_implementation(\n", + " liger_fused_linear_cross_entropy_forward,\n", + " grad_transform=liger_fused_linear_cross_entropy_grad_transform,\n", + " execution_transform=liger_fused_linear_cross_entropy_forward,\n", + ")\n", + "\n", + "\n", + "class 
FuseLinearCrossEntropyTransform(thunder.core.transform_common.Transform):\n", + " def transform_traces_pre_prologue(self, prologue_trace, computation_trace, epilogue_trace, **kwargs):\n", + " _, consumers = thunder.core.utils.producers_and_consumers(computation_trace)\n", + " new_computation_trace = thunder.core.trace.from_trace(computation_trace)\n", + " bsyms_to_skip = set()\n", + " for b in computation_trace.bound_symbols:\n", + " if b in bsyms_to_skip:\n", + " continue\n", + " new_bsym = b\n", + " if b.sym == thunder.torch.linear:\n", + " c = consumers[b.output]\n", + " if len(c) == 1 and c[0].sym == thunder.torch.cross_entropy:\n", + " (ce,) = c\n", + " assert not ce.kwargs\n", + " assert not b.kwargs\n", + " assert ce.args[0] is b.output\n", + " inp, weight, bias = b.args\n", + " _, targets, ce_weight, size_average, ignore_index, reduce, reduction, label_smoothing = ce.args\n", + " assert ce_weight is None\n", + " assert size_average is None\n", + " assert reduce is None\n", + " new_bsym = b.from_bsym(\n", + " sym=liger_fused_linear_cross_entropy_forward,\n", + " output=ce.output,\n", + " args=(inp, weight, targets, bias, ignore_index, label_smoothing, reduction),\n", + " subsymbols=[],\n", + " )\n", + " bsyms_to_skip.add(ce)\n", + " new_computation_trace.bound_symbols.append(new_bsym)\n", + " new_computation_trace.set_provenance(\n", + " thunder.core.trace.TraceProvenance(\"constructed by FuseLinearCrossEntropy\")\n", + " )\n", + " return prologue_trace, new_computation_trace, epilogue_trace" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "89431922-f074-4825-a6b6-7365abe5b0b4", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "def apply_eye_meta(x):\n", + " return thunder.TensorProxy(like=x)\n", + "\n", + "\n", + "def apply_eye(mask):\n", + " mask = mask | torch.eye(mask.shape[-1], dtype=torch.bool, device=mask.device)[None, None]\n", + " return mask\n", + "\n", + "\n", + "t_apply_eye = liger_ex.register_operator(\"t_apply_eye\", fn=apply_eye, meta=apply_eye_meta, replaces=apply_eye)\n", + "\n", + "\n", + "def apply_eye_grad_transform(x):\n", + " return t_apply_eye(x)\n", + "\n", + "\n", + "liger_ex.register_implementation(\n", + " t_apply_eye, execution_transform=apply_eye_grad_transform, grad_transform=apply_eye_grad_transform\n", + ")\n", + "\n", + "\n", + "class GPTForFineTuningLastToken(litgpt.model.GPT):\n", + " def forward(self, idx: torch.Tensor, *, mask: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:\n", + " mask = mask.bool()\n", + " T = idx.size(1)\n", + " if self.max_seq_length < T:\n", + " raise ValueError(f\"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}.\")\n", + "\n", + " attn_mask = (\n", + " litgpt.model.build_mask_cache(mask.shape[-1], mask.device).expand(4, -1, -1, -1) * mask[:, None, None, :]\n", + " )\n", + " attn_mask = apply_eye(attn_mask)\n", + "\n", + " cos = self.cos[:T]\n", + " sin = self.sin[:T]\n", + " x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)\n", + " if self.config.scale_embeddings:\n", + " x = x * torch.tensor(self.config.n_embd**0.5, dtype=x.dtype)\n", + "\n", + " for block in self.transformer.h:\n", + " x = block(x, cos, sin, attn_mask, None)\n", + "\n", + " # second to last prediction is the output\n", + " x = x[:, -2]\n", + " x = self.transformer.ln_f(x)\n", + " x = self.lm_head(x) # (b, t, vocab_size)\n", + " if self.config.final_logit_softcapping is not None:\n", + " x = torch.tanh(x / self.config.final_logit_softcapping) * 
self.config.final_logit_softcapping\n", + " loss = torch.nn.functional.cross_entropy(x, labels)\n", + " return loss\n", + "\n", + "\n", + "cfg = litgpt.Config.from_name(\"Llama-3.2-1B\", n_layer=1)\n", + "with device:\n", + " m = GPTForFineTuningLastToken(cfg)\n", + " m.max_seq_length = 1024\n", + " inp = torch.ones(4, 32, dtype=torch.int64)\n", + " mask = torch.ones(4, 32, dtype=torch.int64)\n", + " labels = torch.ones(4, dtype=torch.int64)\n", + "\n", + "\n", + "jm = thunder.jit(\n", + " m,\n", + " executors=(liger_ex,),\n", + " transforms=(\n", + " MergeRopeTransform(),\n", + " FuseSwigLUTransform(),\n", + " FuseLinearCrossEntropyTransform(),\n", + " ),\n", + ")\n", + "res = jm(inp, mask=mask, labels=labels)\n", + "ref = m(inp, mask=mask, labels=labels)\n", + "\n", + "go = torch.randn_like(res)\n", + "(grad_res,) = torch.autograd.grad(res, jm.get_parameter(\"transformer.wte.weight\"), go)\n", + "(grad_ref,) = torch.autograd.grad(ref, m.get_parameter(\"transformer.wte.weight\"), go)\n", + "\n", + "assert_close(res, ref)\n", + "assert_close(grad_res, grad_ref)\n", + "\n", + "assert any(bsym.sym is liger_rope_forward for bsym in thunder.last_traces(jm)[-1].bound_symbols)\n", + "assert any(bsym.sym is liger_rope_backward for bsym in thunder.last_backward_traces(jm)[-1].bound_symbols)\n", + "assert any(bsym.sym is liger_rms_norm_forward for bsym in thunder.last_traces(jm)[-1].bound_symbols)\n", + "assert any(bsym.sym is liger_rms_norm_backward for bsym in thunder.last_backward_traces(jm)[-1].bound_symbols)\n", + "assert any(bsym.sym is liger_swiglu_forward for bsym in thunder.last_traces(jm)[-1].bound_symbols)\n", + "assert any(bsym.sym is liger_swiglu_backward for bsym in thunder.last_backward_traces(jm)[-1].bound_symbols)\n", + "assert any(bsym.sym is liger_fused_linear_cross_entropy_forward for bsym in thunder.last_traces(jm)[-1].bound_symbols)\n", + "assert any(\n", + " bsym.sym is liger_fused_linear_cross_entropy_backward for bsym in thunder.last_backward_traces(jm)[-1].bound_symbols\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "2812d18c-c6a6-4d02-ba12-7a8002efc0e5", + "metadata": {}, + "source": [ + "# End to end example\n", + "\n", + "adapted from a [Liger-Kernel example](https://github.com/linkedin/Liger-Kernel/blob/de12602d858a6e83aaacc56e5cb64ab218c75a0a/examples/lightning/training.py).\n", + "\n", + "Code below is\n", + "\n", + "Copyright 2024 LinkedIn Corporation ([BSD 2-CLAUSE LICENSE](https://github.com/linkedin/Liger-Kernel/blob/de12602d858a6e83aaacc56e5cb64ab218c75a0a/LICENSE))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "420850af-878a-415a-bacd-a0d3258a0cc3", + "metadata": {}, + "outputs": [], + "source": [ + "if False: # this example has additional dependencies, so we skip it in the CI\n", + " import argparse\n", + " import math\n", + " import os\n", + " from dataclasses import _MISSING_TYPE, dataclass\n", + " import litgpt\n", + " \n", + " import datasets\n", + " import lightning.pytorch as pl\n", + " import torch\n", + " import transformers\n", + " from lightning.pytorch.strategies import DeepSpeedStrategy, FSDPStrategy\n", + " from torch.distributed.fsdp import BackwardPrefetch, MixedPrecision\n", + " from torch.utils.data import DataLoader\n", + " from trl import DataCollatorForCompletionOnlyLM\n", + " import warnings\n", + " \n", + " warnings.simplefilter(action=\"ignore\", category=FutureWarning)\n", + " \n", + " \n", + " _RETAIN_COLUMNS = {\"input_ids\", \"attention_mask\", \"labels\"}\n", + " QUESTION = \"\"\n", + " CHOICES = 
\"\"\n", + " \n", + " \n", + " @dataclass\n", + " class Args:\n", + " model: str = \"meta-llama/Llama-3.2-1B-Instruct\"\n", + " data: str = \"cais/mmlu\"\n", + " output_dir: str = \"mmlu_finetuning\"\n", + " max_length: int = 2048\n", + " # for llam3 8B model, deepspeed will OOM with 16 on 8XA100 80G and 8 will OOM with 8XA100 40G\n", + " batch_size: int = 4\n", + " lr: float = 6e-6\n", + " weight_decay: float = 0.05\n", + " warmup_ratio: float = 0.1\n", + " seed: int = 42\n", + " strategy: str = \"auto\"\n", + " num_gpu: int = 1\n", + " \n", + " \n", + " def warmup_cosine_schedule(warmup_steps, total_steps, min_lr=0):\n", + " def lr_lambda(current_step):\n", + " if current_step < warmup_steps:\n", + " # Linear warmup\n", + " return float(current_step) / float(max(1, warmup_steps))\n", + " else:\n", + " # Cosine annealing\n", + " progress = float(current_step - warmup_steps) / float(max(1, total_steps - warmup_steps))\n", + " return max(min_lr, 0.5 * (1 + math.cos(math.pi * progress)))\n", + " \n", + " return lr_lambda\n", + " \n", + " \n", + " def parse_args() -> Args:\n", + " parser = argparse.ArgumentParser()\n", + " for k, v in Args.__dataclass_fields__.items():\n", + " parser.add_argument(f\"--{k}\", type=v.type, default=v.default)\n", + " parsed = parser.parse_args([])\n", + " return Args(**{k: v for k, v in vars(parsed).items() if not isinstance(v, _MISSING_TYPE)})\n", + " \n", + " \n", + " class LanguageModel(pl.LightningModule):\n", + " def __init__(self, args: Args, tokenizer):\n", + " super().__init__()\n", + " self.args = args\n", + " self.tokenizer = tokenizer\n", + " self.model = None\n", + " \n", + " def configure_model(self):\n", + " # https://lightning.ai/docs/pytorch/stable/advanced/model_parallel/fsdp.html#speed-up-model-initialization\n", + " if self.model is not None:\n", + " return\n", + " self.model = GPTForFineTuningLastToken.from_name(self.args.model.rsplit(\"/\", 1)[-1]).to(torch.bfloat16)\n", + " self.model.load_state_dict(litgpt.utils.lazy_load(f\"checkpoints/{self.args.model}/lit_model.pth\"))\n", + " self.model = thunder.jit(\n", + " self.model,\n", + " executors=(liger_ex, *thunder.get_default_executors()),\n", + " transforms=(MergeRopeTransform(), FuseSwigLUTransform(), FuseLinearCrossEntropyTransform()),\n", + " )\n", + " \n", + " def forward(self, input_ids, attention_mask, labels=None, **kwargs):\n", + " return self.model(idx=input_ids, mask=attention_mask, labels=labels, **kwargs)\n", + " \n", + " def training_step(self, batch):\n", + " outputs = self.model(\n", + " idx=batch[\"input_ids\"],\n", + " mask=batch[\"attention_mask\"],\n", + " labels=batch[\"labels\"][:, -1],\n", + " )\n", + " loss = outputs\n", + " self.log_dict(\n", + " {\"train_loss\": loss},\n", + " on_step=True,\n", + " on_epoch=True,\n", + " prog_bar=True,\n", + " logger=True,\n", + " rank_zero_only=True,\n", + " sync_dist=False,\n", + " )\n", + " return loss\n", + " \n", + " def validation_step(self, batch):\n", + " outputs = self.model(\n", + " idx=batch[\"input_ids\"],\n", + " mask=batch[\"attention_mask\"],\n", + " labels=batch[\"labels\"][:, -1],\n", + " )\n", + " loss = outputs\n", + " self.log_dict(\n", + " {\"val_loss\": loss},\n", + " on_step=True,\n", + " on_epoch=True,\n", + " prog_bar=True,\n", + " logger=True,\n", + " rank_zero_only=True,\n", + " sync_dist=True,\n", + " )\n", + " return loss\n", + " \n", + " def configure_optimizers(self):\n", + " optimizer = torch.optim.AdamW(\n", + " self.parameters(),\n", + " lr=self.args.lr,\n", + " 
weight_decay=self.args.weight_decay,\n", + " fused=True,\n", + " )\n", + " lr_lambda = warmup_cosine_schedule(\n", + " warmup_steps=self.trainer.estimated_stepping_batches * self.args.warmup_ratio,\n", + " total_steps=self.trainer.estimated_stepping_batches,\n", + " min_lr=0,\n", + " )\n", + " lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)\n", + " return {\n", + " \"optimizer\": optimizer,\n", + " \"lr_scheduler\": {\"scheduler\": lr_scheduler, \"interval\": \"step\"},\n", + " }\n", + " \n", + " \n", + " class DataModule(pl.LightningDataModule):\n", + " def __init__(self, tokenizer, args: Args):\n", + " super().__init__()\n", + " self.train_dataset = None\n", + " self.args = args\n", + " self.tokenizer = tokenizer\n", + " self.response_template_str = \" \"\n", + " response_prompt = tokenizer.encode(f\"{self.response_template_str}\", add_special_tokens=False)\n", + " self.collator = DataCollatorForCompletionOnlyLM(\n", + " tokenizer=tokenizer,\n", + " response_template=response_prompt,\n", + " pad_to_multiple_of=16,\n", + " )\n", + " \n", + " def formatting_func(self, example):\n", + " output_texts = []\n", + " for i in range(len(example[\"question\"])):\n", + " choices = \"\"\n", + " for j in range(len(example[\"choices\"][i])):\n", + " choices += f\"{j+1}. {example['choices'][i][j]}; \"\n", + " s = \"Below is a question and multiple choice answers, choices separated by a semicolon. Please select the best answer for the question. \"\n", + " s += f\"{QUESTION}{example['question'][i]} \"\n", + " s += f\"{CHOICES}{choices} \"\n", + " s += f\"{self.response_template_str}{example['answer'][i]}\"\n", + " output_texts.append(s)\n", + " return output_texts\n", + " \n", + " def tokenize(self, example):\n", + " outputs = self.tokenizer(\n", + " self.formatting_func(example),\n", + " truncation=True,\n", + " padding=False,\n", + " max_length=self.args.max_length,\n", + " )\n", + " return {\n", + " \"input_ids\": outputs[\"input_ids\"],\n", + " \"attention_mask\": outputs[\"attention_mask\"],\n", + " }\n", + " \n", + " def setup(self, stage) -> None:\n", + " if self.train_dataset is not None:\n", + " return\n", + " dataset = datasets.load_dataset(self.args.data, \"auxiliary_train\")\n", + " flattened_data = [\n", + " {\n", + " \"answer\": x[\"train\"][\"answer\"],\n", + " \"choices\": x[\"train\"][\"choices\"],\n", + " \"question\": x[\"train\"][\"question\"],\n", + " \"subject\": x[\"train\"][\"subject\"],\n", + " }\n", + " for x in dataset[\"train\"]\n", + " ][:32]\n", + " dataset = datasets.Dataset.from_list(flattened_data)\n", + " dataset = dataset.train_test_split(test_size=4, seed=self.args.seed)\n", + " train_dataset, val_dataset = dataset[\"train\"], dataset[\"test\"]\n", + " self.train_dataset = train_dataset.map(\n", + " self.tokenize,\n", + " remove_columns=list(set(train_dataset.column_names) - _RETAIN_COLUMNS),\n", + " batched=True,\n", + " batch_size=1,\n", + " num_proc=4,\n", + " )\n", + " self.val_dataset = val_dataset.map(\n", + " self.tokenize,\n", + " remove_columns=list(set(val_dataset.column_names) - _RETAIN_COLUMNS),\n", + " batched=True,\n", + " batch_size=1,\n", + " num_proc=4,\n", + " )\n", + " \n", + " def train_dataloader(self):\n", + " return DataLoader(\n", + " self.train_dataset,\n", + " batch_size=self.args.batch_size,\n", + " collate_fn=self.collator,\n", + " )\n", + " \n", + " def val_dataloader(self):\n", + " return DataLoader(\n", + " self.val_dataset,\n", + " batch_size=self.args.batch_size,\n", + " collate_fn=self.collator,\n", + " )\n", + " 
\n", + " \n", + " args = parse_args()\n", + " pl.seed_everything(args.seed)\n", + " os.makedirs(args.output_dir, exist_ok=True)\n", + " \n", + " if args.strategy == \"fsdp\":\n", + " strategy = FSDPStrategy(\n", + " auto_wrap_policy=layers,\n", + " sharding_strategy=\"FULL_SHARD\",\n", + " backward_prefetch=BackwardPrefetch.BACKWARD_PRE,\n", + " sync_module_states=True,\n", + " activation_checkpointing_policy=layers,\n", + " mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16),\n", + " forward_prefetch=True,\n", + " )\n", + " precision = None\n", + " elif args.strategy == \"deepspeed\":\n", + " strategy = DeepSpeedStrategy(stage=3)\n", + " precision = \"bf16-mixed\"\n", + " elif args.strategy == \"ddp\":\n", + " strategy = \"ddp\"\n", + " precision = \"bf16-true\"\n", + " else:\n", + " strategy = \"auto\"\n", + " precision = \"bf16-true\"\n", + "\n", + " # This only works if you have a snapshot to work from.\n", + " trainer = pl.Trainer(\n", + " accelerator=\"cuda\",\n", + " strategy=strategy,\n", + " devices=torch.cuda.device_count() if args.num_gpu is None else args.num_gpu,\n", + " default_root_dir=args.output_dir,\n", + " log_every_n_steps=1,\n", + " max_epochs=1,\n", + " precision=precision,\n", + " )\n", + "\n", + " tokenizer = transformers.AutoTokenizer.from_pretrained(args.model, padding_side=\"left\", truncation_side=\"left\")\n", + " tokenizer.pad_token = tokenizer.eos_token\n", + " data_module = DataModule(\n", + " tokenizer=tokenizer,\n", + " args=args,\n", + " )\n", + "\n", + " model = LanguageModel(args=args, tokenizer=tokenizer)\n", + " trainer.fit(model, datamodule=data_module)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "434dbcf7-1ed3-4669-a90b-12044909be44", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/requirements/base.txt b/requirements/base.txt index fc9b17e751..e1c4766f6a 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -2,7 +2,7 @@ torch >=2.3.0 looseversion ==1.3.0 lightning-utilities >=0.7.0 numpy >=1.23.0,<2 # not yet ready for numpy 2 -igraph >=0.10.4 +networkx >= 3.3 optree >=0.12.1 opt_einsum >= 3.3.0 mpmath <1.4.0 # todo: teporarl pin for `NameError: name '_C' is not defined` diff --git a/requirements/notebooks.txt b/requirements/notebooks.txt index 065c36fe19..d9d6afd90c 100644 --- a/requirements/notebooks.txt +++ b/requirements/notebooks.txt @@ -1,6 +1,5 @@ ipython[all] ~=8.27.0 numpy >=1.23.0,<2 # not yet ready for numpy 2 - +liger-kernel == 0.3.1 cuda-python - -litgpt @ git+https://github.com/Lightning-AI/lit-gpt@940ffc96f7214bca24aa77479bc7c33900aaef28 +litgpt == 0.5.1 diff --git a/thunder/__init__.py b/thunder/__init__.py index 945cf39df5..45a56a16a2 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -14,7 +14,6 @@ from thunder.core.module import ThunderModule from thunder.core.interpreter import InterpreterLogItem from thunder.core.options import ( - resolve_sharp_edges_option, CACHE_OPTIONS, SHARP_EDGES_OPTIONS, ) @@ -124,6 +123,27 @@ "pytorch_executor", # debugging functions "set_execution_callback_file", + "jit", + 
"resolve_executors", + "add_executor_lists", + "get_executor", + "get_all_executors", + "get_default_executors", + "get_always_executors", + "compile_data", + "compile_stats", + "last_traces", + "last_backward_traces", + "cache_option", + "cache_hits", + "cache_misses", + "list_transforms", + "last_interpreter_log", + "last_interpreted_instructions", + "print_last_interpreter_log", + "last_compile_options", + "get_auto_registered_torch_op_names", + "grad", ] @@ -153,13 +173,6 @@ def __version__(): complex64 = dtypes.complex64 complex128 = dtypes.complex128 -# -# Module aliases -# - -# NOTE this allows clang.foo() to be called directly as thunder.foo() -from thunder.clang import * - # # Promoted executor-related functions and objects # diff --git a/thunder/common.py b/thunder/common.py index 92ebeac8e5..4107a00550 100644 --- a/thunder/common.py +++ b/thunder/common.py @@ -47,6 +47,17 @@ import numpy as np import thunder + +__all__ = [ + "CompileStats", + "CompileData", + "cache_put", + "cache_get", + "trace", + "transform_for_execution", + "transform_to_torch_types", +] + # # Datastructures for compiled functions # @@ -709,6 +720,6 @@ def map_to_torch(x: Any) -> Any: last = trace.bound_symbols[-1] assert last.sym.id == prims.PrimIDs.RETURN new_args = tree_map(map_to_torch, last.args) - new_bsym = prims.python_return.bind(*new_args, output=()) + new_bsym = prims.python_return.bind(*new_args, output=None) trace.bound_symbols[-1] = new_bsym return trace diff --git a/thunder/core/codeutils.py b/thunder/core/codeutils.py index 6afbeca7ac..01ad5cfec0 100644 --- a/thunder/core/codeutils.py +++ b/thunder/core/codeutils.py @@ -339,6 +339,19 @@ def _arg_printer(name: str, has_default: bool, default: Any = None) -> str: return f"def {self.name}({arg_str}):" + @staticmethod + def from_name_and_args(name: str, args: Sequence[Any]): + si = SigInfo(name) + for a in args: + if isinstance(a, ProxyInterface): + si.args.append((a.name, None)) + else: + from thunder.core.proxies import proxy + + pa = proxy(a) + si.args.append((pa.name, None)) + return si + # Creates a SigInfo object from a function and the inputs to it # The SigInfo object contains name and value information for the args, varargs, kwargs, and varkwargs diff --git a/thunder/core/interpreter.py b/thunder/core/interpreter.py index d21804d7ff..84577f3d66 100644 --- a/thunder/core/interpreter.py +++ b/thunder/core/interpreter.py @@ -1413,6 +1413,47 @@ def impl(obj, start): return _interpret_call(impl, obj, wrap_const(start)) +def _zip_lookaside(*obj: Iterable, strict=False): + + if not obj: + return + + def zip(*obj, strict=False): + # zip('ABCD', 'xy') --> Ax By + sentinel = object() + iterators = [iter(it) for it in obj] + while iterators: + result = [] + break_loop = False + for it in iterators: + elem = next(it, sentinel) + if elem is sentinel: + if not strict: + return + else: + break_loop = True + break + result.append(elem) + + if break_loop: + break + + yield tuple(result) + if result: + i = len(result) + plural = " " if i == 1 else "s 1-" + msg = f"zip() argument {i+1} is shorter than argument{plural}{i}" + raise ValueError(msg) + sentinel = object() + for i, iterator in enumerate(iterators[1:], 1): + if next(iterator, sentinel) is not sentinel: + plural = " " if i == 1 else "s 1-" + msg = f"zip() argument {i+1} is longer than argument{plural}{i}" + raise ValueError(msg) + + return _interpret_call(zip, *obj, strict=wrap_const(strict)) + + @interpreter_needs_wrap def eval_lookaside( source: str | bytes | bytearray | CodeType, # A python 
expression @@ -2743,6 +2784,7 @@ def _type_call_lookaside(wrapped_typ, *args, **kwargs): any: _any_lookaside, bool: _bool_lookaside, enumerate: _enumerate_lookaside, + zip: _zip_lookaside, exec: exec_lookaside, eval: eval_lookaside, getattr: _getattr_lookaside, diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py index 4ca21bbb2f..5a6efa434e 100644 --- a/thunder/core/jit_ext.py +++ b/thunder/core/jit_ext.py @@ -385,7 +385,7 @@ def wrapper(*args, **kwargs): def record_source_loc_in_symbol_header(fn): @wraps(fn) def wrapper(*args, **kwargs): - runtimectx: Interpreterruntimectx = get_interpreterruntimectx() + runtimectx: InterpreterRuntimeCtx = get_interpreterruntimectx() filename, positions = runtimectx.get_current_user_source_location() ctx: JitCtx = get_jit_ctx() ctx._computation_trace.set_current_source_location(filename, positions) @@ -640,14 +640,7 @@ def _general_jit_torch_autograd_function_apply_lookaside(obj: Any, *args, **kwar with tracectx(trace_of_fwd): prims.python_return(unwrapped_custom_forward_result) - si = SigInfo(symbol_name) - for a in unwrapped_custom_forward_args: - if isinstance(a, Proxy): - si.args.append((a.name, None)) - else: - pa = proxy(a) - si.args.append((pa.name, None)) - trace_of_fwd._siginfo = si + trace_of_fwd._siginfo = SigInfo.from_name_and_args(symbol_name, unwrapped_custom_forward_args) trace_of_fwd.args = unwrapped_custom_forward_args @wraps(trace_of_fwd.python_callable()) @@ -687,14 +680,7 @@ def bind_postprocess(bsym): trace_of_augmented_fwd.add_bound_symbol(bsym) with tracectx(trace_of_augmented_fwd): prims.python_return(augmented_bsym_output) - si = SigInfo(custom_fwd_sym.name) - for a in unwrapped_custom_forward_args: - if isinstance(a, Proxy): - si.args.append((a.name, None)) - else: - pa = proxy(a) - si.args.append((pa.name, None)) - trace_of_augmented_fwd._siginfo = si + trace_of_augmented_fwd._siginfo = SigInfo.from_name_and_args(custom_fwd_sym.name, unwrapped_custom_forward_args) trace_of_augmented_fwd.args = unwrapped_custom_forward_args @wraps(trace_of_augmented_fwd.python_callable()) @@ -739,24 +725,20 @@ def augmented_custom_forward_rule(*args, **kwargs): for bsym in custom_bwd_bsyms: trace_of_backward.add_bound_symbol(bsym) with tracectx(trace_of_backward): - prims.python_return.bind(*unwrap(custom_backward_result), output=()) + prims.python_return.bind(*unwrap(custom_backward_result), output=None) @wraps(trace_of_backward.python_callable()) def bwd_trace_callable_interface(*args, **kwargs): return thunder.core.trace_interpreter.interpret_trace(trace_of_backward, *args, **kwargs) - bwd_si = SigInfo("backward_impl") - for a in ctx_proxy.saved_consts + ctx_proxy.saved_tensors + grads: - if isinstance(a, Proxy): - bwd_si.args.append((a.name, None)) - else: - pa = proxy(a) - bwd_si.args.append((pa.name, None)) bwd_trace_impl = TraceCtx() for bsym in custom_bwd_bsyms: bwd_trace_impl.add_bound_symbol(bsym) - bwd_trace_impl.add_bound_symbol(prims.python_return.bind(*sequencify(unwrap(custom_backward_result)), output=())) - bwd_trace_impl._siginfo = bwd_si + bwd_trace_impl.add_bound_symbol(prims.python_return.bind(*sequencify(unwrap(custom_backward_result)), output=None)) + bwd_trace_impl._siginfo = SigInfo.from_name_and_args( + "backward_impl", + ctx_proxy.saved_consts + ctx_proxy.saved_tensors + grads, + ) bwd_trace_impl.args = tuple(ctx_proxy.saved_consts + ctx_proxy.saved_tensors + grads) @wraps(bwd_trace_impl.python_callable()) @@ -904,7 +886,7 @@ def general_jit_lookaside(fn, *args, **kwargs) -> None | Callable: lookaside = 
executor_lookaside # the ad hoc executor may be extended during compilation elif (executor_lookaside := ctx.ad_hoc_executor._lookasides.get(fn, None)) is not None: - lookaside = jit_needs_wrap(executor_lookaside) + lookaside = interpreter_needs_wrap(executor_lookaside) elif isinstance(fn, Symbol) or fn in _clang_fn_set: # Performs symbol lookasides # NOTE Symbols "lookaside" to themselves; this just prevents their internals from being jitted @@ -1007,7 +989,7 @@ def _maybe_update_proxy_name(orig_value: Any, name: str, is_internal: bool | Non } if is_internal is None: - runtimectx: Interpreterruntimectx = get_interpreterruntimectx() + runtimectx: InterpreterRuntimeCtx = get_interpreterruntimectx() frame = runtimectx.peek_frame_stack() assert frame is not None # pass is_internal if you call this before the frame is set up is_internal = frame.module in {"thunder.core.interpreter", "thunder.core.jit_ext"} @@ -1402,7 +1384,7 @@ def from_binary_subscr(provenance, *, new_output=False): output = Proxy("subscr") # name? collectify? else: output = p - if isinstance(idx, (int, str)): + if isinstance(idx, (int, str, Proxy)): if isinstance(idx, int): idx = int(idx) elif isinstance(idx, str): diff --git a/thunder/core/prims.py b/thunder/core/prims.py index 588043332c..44b0678404 100644 --- a/thunder/core/prims.py +++ b/thunder/core/prims.py @@ -1652,8 +1652,8 @@ def _del_impl(x: Any, /) -> None: ) -def _return_meta(*args) -> Any: - return args +def _return_meta(*args) -> None: + return None def return_printer( @@ -1674,9 +1674,8 @@ def return_printer( return f"return {arg_str}" -# NOTE This wrapper for del is necessary because python_impl=del is invalid syntax (del is not a regular function) -def _return_impl(*args) -> Any: - return args +def _return_impl(*args) -> None: + return None python_return = make_prim( diff --git a/thunder/core/proxies.py b/thunder/core/proxies.py index ae7744dda3..626f500e6c 100644 --- a/thunder/core/proxies.py +++ b/thunder/core/proxies.py @@ -3,7 +3,7 @@ import copy from enum import auto, Enum from numbers import Number -from typing import Type, Optional, Any, Tuple, List, Union +from typing import Any from collections.abc import Callable from collections.abc import Sequence @@ -1974,8 +1974,8 @@ def tensorproxy(t: torch.Tensor, /, *, name: None | str, history: None | tuple = for idx, s in enumerate(t.shape) ) else: + # NOTE Without tuple(t.shape) then the shape would be a torch.Size object shape = tuple(t.shape) - # NOTE Without tuple(t.shape) then the shape would be a torch.Size object return TensorProxy( name, shape=tuple(shape), diff --git a/thunder/core/rematerialization.py b/thunder/core/rematerialization.py index 7db7b88f64..53f60ea5ae 100644 --- a/thunder/core/rematerialization.py +++ b/thunder/core/rematerialization.py @@ -7,7 +7,7 @@ from collections import defaultdict import time -from igraph import Graph +import networkx as nx from thunder.core import prims, utils from thunder.core.baseutils import BoundSymbolInterface, ProxyInterface @@ -163,6 +163,11 @@ def apply_rematerialization_for_consumer( filter(lambda x: x.name not in map(lambda x: x.name, new_consumer_args), consumer.args) ) + # In the case where there are no tensors to rematerialize it is + # possible to terminate early and return the consumer as it was. + if not rematerialized_inputs: + return consumer + # Construct a temporary Trace object with subsymbols from the producer. 
trace = TraceCtx(None) trace.bound_symbols = producer.subsymbols @@ -312,12 +317,9 @@ def find_cut( # Create a graph edges = [] - name_to_id = {} - capacities = [] def add_edge(src, dst, capacity): - edges.append((name_to_id.setdefault(src, len(name_to_id)), name_to_id.setdefault(dst, len(name_to_id)))) - capacities.append(capacity) + edges.append((src, dst, {"capacity": capacity})) utils.check( len(required_consumer_vars) > 0, @@ -369,23 +371,17 @@ def add_edges(var): for var in symbol.flat_proxy_outs: add_edges(var) - g = Graph( - n=len(name_to_id), - edges=edges, - directed=True, - edge_attrs={"capacity": capacities}, - ) - source = name_to_id["source"] - sink = name_to_id["sink"] + g = nx.DiGraph() + g.add_edges_from(edges) + + _, (reachable, non_reachable) = nx.minimum_cut(g, "source", "sink") - id_to_name = dict(map(reversed, name_to_id.items())) + cut_edges = set() + for u, nbrs in ((n, g[n]) for n in reachable): + cut_edges.update((u, v) for v in nbrs if v in non_reachable) - g_edges = g.get_edgelist() - cut = g.mincut(source, sink, "capacity").cut cut_nodes = set() - for cut_edge_id in cut: - u, v = g_edges[cut_edge_id] - node_in, node_out = id_to_name[u], id_to_name[v] + for node_in, node_out in cut_edges: if node_out == "sink": continue assert node_in.endswith("_in"), node_in diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py index 71d6872a66..4e81bf3261 100644 --- a/thunder/core/transforms.py +++ b/thunder/core/transforms.py @@ -2887,7 +2887,7 @@ def _update_forward_with_new_saved_for_backward(forward_trace: Trace, saved_for_ saved_tensors, saved_other = _split_saved_for_backward_into_tensors_and_other(saved_for_backward) assert forward_trace.bound_symbols[-1].sym.id == prims.PrimIDs.RETURN new_return = (forward_trace.output[0], (saved_tensors, saved_other)) - forward_trace.bound_symbols[-1] = replace(forward_trace.bound_symbols[-1], args=new_return, output=new_return) + forward_trace.bound_symbols[-1] = replace(forward_trace.bound_symbols[-1], args=new_return) def _update_backward_with_new_saved_for_backward(backward_trace: Trace, saved_for_backward: Sequence[Variable]) -> None: @@ -3144,7 +3144,7 @@ def recompute_saved_for_backward(fwd_trace: Trace, bwd_trace: Trace) -> tuple[Tr new_fwd_trace = from_trace(fwd_trace) new_fwd_trace.bound_symbols = fwd_trace.bound_symbols.copy() new_return_args = (fwd_trace.output[0], (new_saved_for_backward, fwd_trace.output[1][1])) - new_fwd_trace.bound_symbols[-1] = prims.python_return.bind(*new_return_args, output=()) + new_fwd_trace.bound_symbols[-1] = prims.python_return.bind(*new_return_args, output=None) new_bwd_trace = from_trace(bwd_trace) # In cases where C0 name is carried from previous trace it must be removed diff --git a/thunder/core/utils.py b/thunder/core/utils.py index edce139cc3..a5056071d5 100644 --- a/thunder/core/utils.py +++ b/thunder/core/utils.py @@ -8,12 +8,13 @@ from typing import Any, overload, Generic, Optional, TypeVar, TYPE_CHECKING from collections.abc import Callable from collections.abc import Hashable, Iterable, Iterator, Sequence +from collections import defaultdict from typing_extensions import Self import thunder.core.dtypes as dtypes from thunder.core.pytree import tree_flatten, tree_unflatten, tree_map -from thunder.core.proxies import Proxy, NumberProxy, TensorProxy, variableify, CONSTRAINT +from thunder.core.proxies import Proxy, NumberProxy, TensorProxy, variableify, CONSTRAINT, Variable from thunder.core.baseutils import * from thunder.core.codeutils import * from thunder.core.trace 
import TraceCtx @@ -1144,7 +1145,7 @@ def get_symbols_to_last_used_variables(symbols, ignore): ignore = (ignore,) if not isinstance(ignore, Sequence) else ignore ignore = tree_flatten(ignore)[0] variable_to_last_symbol = {} - symbol_to_last_variables = {} + symbol_to_last_variables = defaultdict(list) def _mark_last_use(symbol, variable): if variable in ignore: @@ -1157,10 +1158,10 @@ def _mark_last_use(symbol, variable): # If this function is used in the combined nvfuser+torch executor, there are no symbols but regions. # Regions do not have args, kwargs if hasattr(symbol, "inputs"): - variables = tuple(symbol.inputs) + variables = tuple(symbol.inputs) + tuple(symbol.outputs) else: - variables = (symbol.args, symbol.kwargs) - tree_map(lambda x: _mark_last_use(symbol, x) if isinstance(x, trace.Variable) else None, variables) + variables = (symbol.flat_variableified_proxy_args) + tuple(symbol.flat_variableified_proxy_outs) + tree_map(lambda x: _mark_last_use(symbol, x) if isinstance(x, Variable) else None, variables) return symbol_to_last_variables diff --git a/thunder/core/vjp_utils.py b/thunder/core/vjp_utils.py index 08b34564e5..0921bf7c6e 100644 --- a/thunder/core/vjp_utils.py +++ b/thunder/core/vjp_utils.py @@ -104,7 +104,7 @@ def find_backward_output(forward_input): prims.PrimIDs.GET_GRAD, ) backward_bsyms = [bsym for bsym in backward_bsyms if bsym.sym.id not in skip] - backward_bsyms.append(prims.python_return.bind(bw_outputs, output=())) + backward_bsyms.append(prims.python_return.bind(bw_outputs, output=None)) forward_input_proxies = tree_flatten((joint_trace.args, joint_trace.kwargs))[0] forward_input_proxies = [arg for arg in forward_input_proxies if isinstance(arg, Proxy)] @@ -131,7 +131,7 @@ def find_backward_output(forward_input): return_bsym = augmented_forward_trace.bound_symbols[-1] assert return_bsym.sym.id == PrimIDs.RETURN augmented_forward_trace.bound_symbols[-1] = prims.python_return.bind( - (joint_trace.output, saved_for_backward), output=() + (joint_trace.output, saved_for_backward), output=None ) # Remove put/get grad and backward symbols from augmented forward trace augmented_forward_trace = dce(augmented_forward_trace) @@ -156,7 +156,7 @@ def find_backward_output(forward_input): additional_saved = [o for bsym in same_bsyms for o in bsym.flat_proxy_outs] saved_for_backward += list({variableify(arg): arg for arg in additional_saved}.values()) augmented_forward_trace.bound_symbols[-1] = prims.python_return.bind( - (joint_trace.output, saved_for_backward), output=() + (joint_trace.output, saved_for_backward), output=None ) backward_params = [ diff --git a/thunder/dev_utils/debug_transform.py b/thunder/dev_utils/debug_transform.py index ea333f545a..53d12aef8c 100644 --- a/thunder/dev_utils/debug_transform.py +++ b/thunder/dev_utils/debug_transform.py @@ -12,7 +12,7 @@ def create_debug_boundsymbol(name: str, bsym: BoundSymbol, call_ctx: Callable): def bind_postprocess(debug_bsym): debug_bsym._call_ctx = {name: partial(call_ctx, debug_bsym, bsym)} - debug_sym = Symbol(name, lambda *_: None, is_prim=True, _bind_postprocess=bind_postprocess) + debug_sym = Symbol(name, lambda *_, **__: None, is_prim=True, _bind_postprocess=bind_postprocess) debug_bsym = debug_sym.bind(*bsym.args, output=None, **bsym.kwargs) return debug_bsym diff --git a/thunder/dynamo/compiler.py b/thunder/dynamo/compiler.py index 957ca65188..f55ee29657 100644 --- a/thunder/dynamo/compiler.py +++ b/thunder/dynamo/compiler.py @@ -1,14 +1,18 @@ +from __future__ import annotations from functools import 
partial from looseversion import LooseVersion +from typing import TYPE_CHECKING +import warnings import torch -import warnings from thunder.core.baseutils import run_once - -from thunder.dynamo.utils import SubgraphInfo, recompile_graph +from thunder.dynamo.utils import recompile_graph from thunder.dynamo.splitter import _splitter +if TYPE_CHECKING: + from thunder.dynamo.utils import SubgraphInfo + @run_once def _warn_thunder_compiler(): @@ -21,12 +25,12 @@ def _warn_thunder_compiler(): class ThunderCompiler: def __init__(self, **thunder_options): """ - A class that compiles a `fx.GraphModule` to a `thunder.ThunderModule`. - This class is meant to be used as a backend for the `torch.compile` + A class that compiles a :class:`torch.fx.GraphModule` to a :class:`thunder.ThunderModule`. + This class is meant to be used as a backend for the :func:`torch.compile` function. Keyword arguments: - thunder_options: a dictionary of options to pass to `thunder.jit`. Besides all the arguments to `thunder.jit`, + thunder_options: a dictionary of options to pass to :func:`thunder.jit`. Besides all the arguments to :func:`thunder.jit`, it accepts `torch_inductor_options` which are passed to `torch.compile` if part of the graph is not supported by thunder. @@ -44,7 +48,7 @@ def __init__(self, **thunder_options): ... return x - 1 >>> out = func(x) """ - from thunder import ThunderModule, jit + from thunder import jit _warn_thunder_compiler() diff --git a/thunder/dynamo/splitter.py b/thunder/dynamo/splitter.py index 6838d23102..e2aaf8ad50 100644 --- a/thunder/dynamo/splitter.py +++ b/thunder/dynamo/splitter.py @@ -1,4 +1,5 @@ -from collections.abc import Callable +from __future__ import annotations +from typing import TYPE_CHECKING import torch from torch.fx.passes.split_module import split_module @@ -15,6 +16,9 @@ recompile_graph, ) +if TYPE_CHECKING: + from collections.abc import Callable + def _splitter( gm: torch.fx.GraphModule, diff --git a/thunder/dynamo/utils.py b/thunder/dynamo/utils.py index 49d967a456..b0ef08097e 100644 --- a/thunder/dynamo/utils.py +++ b/thunder/dynamo/utils.py @@ -1,8 +1,10 @@ +from __future__ import annotations +from collections.abc import Callable from enum import Enum, auto +from typing import TYPE_CHECKING import dataclasses -from collections.abc import Callable -import itertools import inspect +import itertools import torch @@ -10,9 +12,17 @@ from thunder.torch import _torch_to_thunder_function_map from thunder.torch.langctx import torchctx +if TYPE_CHECKING: + from thunder.core.symbol import Symbol + auto_register_ops = set(itertools.chain(*torch_auto_registered_ops.values())) +# Currently, thunder has a mapping for these torch functions, but the thunder +# implementations just raise a warning that they are unsupported. +UNSUPPORTED_THUNDER_FUNCTION = (torch._C._set_grad_enabled,) + + class CompilerType(Enum): """ An enumeration representing different types of compilers. @@ -49,36 +59,38 @@ class SplitReasonType(Enum): @dataclasses.dataclass(frozen=True) class SplitReason: - """ - A dataclass containing information about a split. + """A dataclass containing information about a split. Attributes: - type (SplitReasonType): Reason for the split. - info (str): String with details of what caused the split. - exception (Exception | None): Exception if there was any. + reason_type: Reason for the split. + info: String with details of what caused the split. + exception: Exception if there was any. 
""" - type: SplitReasonType + reason_type: SplitReasonType info: str | None exception: Exception | None = None @dataclasses.dataclass(frozen=True) class SubgraphInfo: - """ - A dataclass containing information about a subgraph. + """A dataclass containing information about a subgraph. Attributes: - original_graph_module (torch.fx.GraphModule): The original graph module. - split_graph_module (torch.fx.GraphModule): Optional. The graph module for the split subgraph. - thunder_compiled_fns (list[Callable]): List of thunder optimized callables. This could be None if there the graph module was not supported by thunder. Look at the `split_reasons` for further information. - compiled_functions (list[CompiledFunction]): A list of compiled functions derived from the subgraph. This will be a list with one function in case the graph was not split. - split_reasons (list[SplitReason] | None): Optional list of reasons explaining why the subgraph was split. Present only if there are was a split. + original_graph_module: The original graph module. + split_graph_module: The graph module for the split subgraph. + thunder_compiled_fns: List of thunder optimized callables. + This could be :obj:`None` if the graph module was not supported by thunder. + Look at the :attr:`split_reasons` for further information. + submodule_to_compiled_functions: Dict from subgraph to compiled function. + This will be a dict with one pair in case the graph was not split. + split_reasons: List of reasons explaining why the subgraph was split. + Present only if there was a split. """ original_graph_module: torch.fx.GraphModule - split_graph_module: torch.fx.GraphModule - thunder_compiled_fns: list[Callable] + split_graph_module: torch.fx.GraphModule | None + thunder_compiled_fns: list[Callable] | None submodule_to_compiled_functions: dict[torch.fx.GraphModule, CompiledFunction] split_reasons: list | None = None @@ -143,7 +155,7 @@ def make_tensor_proxy(arg_node): return proxy_args, proxy_kwargs -def try_execute_thunder_symbol(thunder_symbol: "Symbol", node: torch.fx.Node) -> tuple[bool, SplitReason | None]: +def try_execute_thunder_symbol(thunder_symbol: Symbol, node: torch.fx.Node) -> tuple[bool, SplitReason | None]: """ Attempts to execute a given Thunder symbol within a tracing context, using proxies for the node's arguments. @@ -214,10 +226,29 @@ def get_nodes_in_unsupported_ctx_regions(gm: torch.fx.GraphModule) -> set[torch. # We want to mark nodes with `_enter_autocast` and `_exit_autocast` # as unsupported as `thunder` doesn't correctly deal with these stateful functions. + + def is_no_grad_ctx_enter(node): + if node.target == torch._C._set_grad_enabled: + arg: bool = node.args[0] + assert isinstance(arg, bool) + return not arg # arg is False (i.e. grad was disabled) + return False + + def is_no_grad_ctx_exit(node): + if node.target == torch._C._set_grad_enabled: + arg: bool = node.args[0] + assert isinstance(arg, bool) + return arg # arg is True (i.e. 
grad was enabled) + return False + for node in gm.graph.nodes: - if node.op == "call_function" and node.target in (torch.amp.autocast_mode._enter_autocast,): + if node.op == "call_function" and ( + node.target in (torch.amp.autocast_mode._enter_autocast,) or is_no_grad_ctx_enter(node) + ): ctx_cnt += 1 - elif node.op == "call_function" and node.target in (torch.amp.autocast_mode._exit_autocast,): + elif node.op == "call_function" and ( + node.target in (torch.amp.autocast_mode._exit_autocast,) or is_no_grad_ctx_exit(node) + ): ctx_cnt -= 1 else: if ctx_cnt > 0: @@ -264,6 +295,15 @@ def is_node_supported_by_thunder(node: torch.fx.Node) -> tuple[bool, SplitReason ) return False, split_reason + # These functions are present in `_torch_to_thunder_function_map` but don't mimic exact behavior. + # Eg. torch._C._set_grad_enabled's thunder implementation just throws warning that this is unsupported. + if target in UNSUPPORTED_THUNDER_FUNCTION: + split_reason = SplitReason( + SplitReasonType.UNSUPPORTED_NODE, + info=f"node with name: {node.name} and target: {node.target} has been manually disabled.", + ) + return False, split_reason + # If thunder has a mapping for this operation, try executing the meta function and see. # We have a symbol for `torch.where`, but we don't support one overload of it. # So, we try and execute the meta to get a real signal. @@ -303,7 +343,10 @@ def is_node_supported_by_thunder(node: torch.fx.Node) -> tuple[bool, SplitReason def update_node_and_submodule( - graph_module: torch.fx.GraphModule, node: torch.fx.Node, new_name: str, new_callable: Callable + graph_module: torch.fx.GraphModule, + node: torch.fx.Node, + new_name: str, + new_callable: Callable, ): """ Updates the graph module and the node in place with a new name and a new callable as the target. 
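Taken together, the splitter changes above mean that a `torch.no_grad()` region captured by dynamo is now reported as a split instead of being handed to thunder. Below is a minimal sketch (not part of the patch) of how this surfaces through the public API; the names `ThunderCompiler`, `subgraph_infos`, `split_reasons`, and `reason_type` come from this diff, while the exact number of splits depends on the captured graph:

```python
import torch
from thunder.dynamo import ThunderCompiler

backend = ThunderCompiler()

def func(x):
    # The no_grad block is captured by dynamo as a pair of
    # torch._C._set_grad_enabled(False/True) call_function nodes,
    # which the splitter now treats as an unsupported region.
    with torch.no_grad():
        y = x @ x
    return y + x

x = torch.randn(3, 3, requires_grad=True)
out = torch.compile(func, backend=backend)(x)

# Each FX graph handed to the backend gets a SubgraphInfo; regions that were
# not given to thunder are recorded as SplitReason entries.
for info in backend.subgraph_infos:
    for reason in info.split_reasons or []:
        print(reason.reason_type, reason.info)
```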
diff --git a/thunder/executors/nvfuserex_impl.py b/thunder/executors/nvfuserex_impl.py index 571a55adff..c85ad67cb2 100644 --- a/thunder/executors/nvfuserex_impl.py +++ b/thunder/executors/nvfuserex_impl.py @@ -675,7 +675,7 @@ def has_cuda_input_or_output(self, bsym: BoundSymbol) -> bool: def _dce_bsyms(self, input_list, output, bsyms: list[BoundSymbol]) -> list[BoundSymbol]: trace = TraceCtx(None) trace.bound_symbols = bsyms - bsyms.append(prims.python_return.bind(output, output=())) + bsyms.append(prims.python_return.bind(output, output=None)) needed_proxies: set[Variable] = set() trace = dce(trace, needed_proxies) # update the input_list by removing the unused inputs @@ -787,7 +787,7 @@ def map_redundant(x: Any) -> Any: return_bsym = cse_trace.bound_symbols[-1] assert return_bsym.sym.id == prims.PrimIDs.RETURN trace_output = tree_map(map_redundant, return_bsym.args) - cse_trace.bound_symbols[-1] = prims.python_return.bind(*trace_output, output=()) + cse_trace.bound_symbols[-1] = prims.python_return.bind(*trace_output, output=None) end_time_ns = time.perf_counter_ns() elapsed_time_ns = end_time_ns - start_time_ns diff --git a/thunder/executors/torch_compile.py b/thunder/executors/torch_compile.py index ef2932ac3d..3e7e7ed419 100644 --- a/thunder/executors/torch_compile.py +++ b/thunder/executors/torch_compile.py @@ -69,7 +69,7 @@ def make_compiled( region_trace.bound_symbols = list(bsyms) region_trace.args = sorted_unique_inputs region_trace.kwargs = {} - region_trace.bound_symbols.append(prims.python_return.bind(sorted_unique_outputs, output=())) + region_trace.bound_symbols.append(prims.python_return.bind(sorted_unique_outputs, output=None)) def torch_interpreted_func(*args): return eval_trace(region_trace, *args, symbol_mapper=to_torch_translator) diff --git a/thunder/tests/test_core.py b/thunder/tests/test_core.py index 3449d0292b..345eabd921 100644 --- a/thunder/tests/test_core.py +++ b/thunder/tests/test_core.py @@ -2916,3 +2916,53 @@ def test_user_module_is_freed(): del mod del opt_mod assert ref_mod() is None + + +@pytest.mark.parametrize("requires_grad", [True, False]) +def test_return_bsym_has_none_output(requires_grad): + def fn(x): + return x + 1 + + x = torch.tensor([3.0], requires_grad=requires_grad) + jfn = thunder.jit(fn) + jfn(x) + + for trace in thunder.last_traces(jfn): + return_bsym = trace.bound_symbols[-1] + assert return_bsym.sym.id == thunder.prims.PrimIDs.RETURN + assert return_bsym.output is None + + if requires_grad: + for trace in thunder.last_backward_traces(jfn): + return_bsym = trace.bound_symbols[-1] + assert return_bsym.sym.id == thunder.prims.PrimIDs.RETURN + assert return_bsym.output is None + + +def test_indexing_with_hashable_object(): + class HashableClass: + def __hash__(self): + return id(self) + + h = HashableClass() + d = {h: 1, 1: 0} + + def fn(): + return d[h] + + jfn = thunder.jit(fn) + assert jfn() == 1 + assert thunder.cache_misses(jfn) == 1 # Due to first compilation. + + # Call jfn with no changes + # this should be cache hit. + assert jfn() == 1 + assert thunder.cache_hits(jfn) == 1 + assert thunder.cache_misses(jfn) == 1 + + # Change the value of the captured dict. + # This should be a cache miss, verify that. 
+ d[h] = 2 + assert jfn() == 2 # Verify that jfn now returns 2 + assert thunder.cache_hits(jfn) == 1 + assert thunder.cache_misses(jfn) == 2 diff --git a/thunder/tests/test_dynamo.py b/thunder/tests/test_dynamo.py index afc2e9e29d..a667405a13 100644 --- a/thunder/tests/test_dynamo.py +++ b/thunder/tests/test_dynamo.py @@ -435,3 +435,33 @@ def test_thundercompiler_optim_step(executor, device, dtype, optim): tuple(ref_model.parameters()), msg=lambda s: f"{i+1}-iter {s}", ) + + +@instantiate(dtypes=NOTHING, executors=[DynamoThunderExecutor]) +def test_no_grad_ctx_manager(executor, device: str, dtype: dtypes.dtype): + backend = ThunderCompiler() + + def func(x): + with torch.no_grad(): + with torch.autocast("cuda", dtype=torch.bfloat16): + y = x @ x + return y + x + + x = torch.randn(3, 3, device=device, dtype=dtype, requires_grad=True) + actual = torch.compile(func, backend=backend)(x) + expected = torch.compile(func, backend="eager")(x) + + # We record the GraphModules that was compiled by ThunderCompiler + assert len(backend.subgraph_infos) == 1 + + for subgraph_info in backend.subgraph_infos: + assert len(subgraph_info.split_reasons) > 1 # Verify there were splits in the subgraph. + assert isinstance(subgraph_info.original_graph_module, torch.fx.GraphModule) + assert any("has been manually disabled" in split_reason.info for split_reason in subgraph_info.split_reasons) + + torch.testing.assert_close(actual, expected) + + g = torch.randn_like(actual) + actual_grad = torch.autograd.grad(actual, x, g) + expected_grad = torch.autograd.grad(expected, x, g) + torch.testing.assert_close(actual_grad, expected_grad) diff --git a/thunder/tests/test_examine_memory.py b/thunder/tests/test_examine_memory.py index cd02b8ab5a..982d9d8ff3 100644 --- a/thunder/tests/test_examine_memory.py +++ b/thunder/tests/test_examine_memory.py @@ -98,6 +98,11 @@ def bar2(a, b): # [5,2], [2,2] @requiresCUDA def test_nanogpt_block(): + # The estimated memory usage is not the same as actual peak memory usage on Hopper + if torch.cuda.get_device_capability() >= (9, 0): + pytest.skip( + f"the estimated memory usage is not the same as actual peak memory usage on {torch.cuda.get_device_name()}" + ) import thunder.tests.nanogpt_model as nanogpt_model config = nanogpt_model.GPTConfig(dropout=0) diff --git a/thunder/tests/test_interpreter.py b/thunder/tests/test_interpreter.py index fbb3a32ee1..e69dc2fff3 100644 --- a/thunder/tests/test_interpreter.py +++ b/thunder/tests/test_interpreter.py @@ -1691,6 +1691,42 @@ def fn(): jit(fn)() +def test_zip_lookaside(jit): + import re + + jitting = False + + def foo(*a, strict=False): + return list(zip(*a, strict=strict)) + + jfoo = jit(foo) + jitting = False + + res1 = foo([1, 2, 3], [4, 5, 6]) + res2 = foo([1, 2, 3], [4, 5, 6], [7, 8, 9]) + res3 = foo([1, 2], [4, 5, 6]) + res4 = foo("abc", "xyz") + # , match="zip() argument 2 is longer than argument 1" + + with pytest.raises(ValueError, match=re.escape("zip() argument 2 is longer than argument 1")): + res5 = foo([1, 2], [4, 5, 6], strict=True) + + jitting = True + jres1 = jfoo([1, 2, 3], [4, 5, 6]) + jres2 = jfoo([1, 2, 3], [4, 5, 6], [7, 8, 9]) + jres3 = jfoo([1, 2], [4, 5, 6]) + jres4 = jfoo("abc", "xyz") + + # , match=" zip() argument 2 is longer than argument 1" + with pytest.raises(ValueError, match=re.escape("zip() argument 2 is longer than argument 1")): + jres5 = jfoo([1, 2], [4, 5, 6], strict=True) + + assert res1 == jres1 + assert res2 == jres2 + assert res3 == jres3 + assert res4 == jres4 + + def test_enumerate_lookaside(jit): 
jitting = False diff --git a/thunder/tests/test_nvfuser.py b/thunder/tests/test_nvfuser.py index 2224bfc9ba..f6e68f0e23 100644 --- a/thunder/tests/test_nvfuser.py +++ b/thunder/tests/test_nvfuser.py @@ -210,7 +210,7 @@ def foo(a, x): assert len(fusions[0].subsymbols) == 3 # Verifies the intermediate consumer - assert fusions[1].subsymbols[-2].args[0].name == "g" + assert fusions[1].subsymbols[-1].args[0].name == "g" @instantiate(executors=(nvFuserExecutor,), dtypes=(thunder.float32,)) @@ -281,16 +281,13 @@ def func(x): # There are two nvfuser fusion groups separated by the matmul operation. assert len(fusion_bsyms) == 2 - nvf_0, nvf_1 = fusion_bsyms # CSE removes the redundant (t0 + 5) operation - assert len(nvf_0.subsymbols) == 5 - # Return t0 and t1 from the first fusion - assert [t.name for t in tree_flatten(nvf_0.output)[0]] == ["t1", "t4"] + nvf_0, nvf_1 = fusion_bsyms + assert len(nvf_0.subsymbols) + len(nvf_1.subsymbols) == 7 - # CSE does not change the second fusion - assert len(nvf_1.subsymbols) == 2 - assert [t.name for t in tree_flatten(nvf_1.output)[0]] == ["t10"] + outside_fusion_syms = ["unpack_trivial", "matmul", "python_return", "python_del"] + assert {el.sym.name for el in fw_trace.bound_symbols if not el.sym.is_fusion} == set(outside_fusion_syms) @instantiate(dtypes=NOTHING, devicetypes=(devices.DeviceType.CUDA,), executors=(nvFuserExecutor,)) diff --git a/thunder/tests/test_nvfuser_remat.py b/thunder/tests/test_nvfuser_remat.py index d0f3462391..d5b57320fa 100644 --- a/thunder/tests/test_nvfuser_remat.py +++ b/thunder/tests/test_nvfuser_remat.py @@ -19,6 +19,7 @@ from thunder.examine import get_fusions from thunder.tests.framework import instantiate, NOTHING, nvFuserExecutor, TorchExecutor, requiresCUDA from thunder.tests.make_tensor import make_tensor +import thunder.torch as ltorch @value_and_grad @@ -199,6 +200,40 @@ def test_apply_rematerialization_consumer(executor, device, _): assert tuple(new_consumer.subsymbols) == tuple(new_consumer_case2.subsymbols) +@instantiate( + dtypes=NOTHING, + executors=(nvFuserExecutor,), +) +@disable_rematerialization_in_nvfuser_fusion +def test_apply_rematerialization_consumer_early_exit(executor, device, _): + @value_and_grad + def foo(t0): + t1 = ttorch.exp(t0) + t2 = ttorch.matmul(t1, t1) + return t2 + + t0 = make_tensor(2, 2, dtype=torch.float32, device=device) + initial_trace = thunder.trace()(foo, t0) + compiled_func = thunder.jit(initial_trace.python_callable()) + _ = compiled_func(t0) + traces = thunder.last_traces(compiled_func) + trace = traces[-1] + nvfuser_symbols = tuple(filter(lambda x: x.sym.name.startswith("nvFusion"), trace.bound_symbols)) + assert len(nvfuser_symbols) == 2 + + producer = nvfuser_symbols[0] + consumer = nvfuser_symbols[1] + + # Create a cut that has t0 as extra information and + # that contains all arguments(t2) from consumer. 
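+    # Since the cut already provides every input the consumer needs ("t2"),
+    # apply_rematerialization_for_consumer is expected to take its early-exit
+    # path and return the consumer unchanged, as the assertions below verify.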
+ cut = ("t0", "t2") + new_consumer = apply_rematerialization_for_consumer(producer, consumer, cut) + + # Check that the new consumer is the old consumer + assert id(new_consumer) == id(consumer) + assert tuple(new_consumer.subsymbols) == tuple(consumer.subsymbols) + + @instantiate( dtypes=NOTHING, executors=(nvFuserExecutor,), diff --git a/thunder/tests/test_transforms.py b/thunder/tests/test_transforms.py index e8c1fe7558..f0514b1e84 100644 --- a/thunder/tests/test_transforms.py +++ b/thunder/tests/test_transforms.py @@ -497,3 +497,98 @@ def forward(self, x): original_model.load_state_dict(rename_state_dict, strict=False) litgpt_lora_output = original_model(x) assert_close(actual, litgpt_lora_output, atol=2e-1, rtol=2e-1) + + +def test_constant_folding(): + + # Helper to verify we see the expected constant tensors + # in exec_trace. + def assert_in_trace(exec_trace, sym, arg_vals): + for bsym in exec_trace.bound_symbols: + if bsym.sym.id == sym and bsym.args == arg_vals: + return + + err = f"Expected to find symbol {sym} with arguments {arg_vals} in execution trace but didn't find any." + raise RuntimeError(err) + + from thunder.transforms.constant_folding import ConstantFolding + + def forward(): + const_t = torch.tensor([2]) + getitem = (const_t * 2)[0] + return (getitem, const_t) # (4, [2]) + + jforward = thunder.jit(forward, transforms=[ConstantFolding()]) + actual = jforward() + expected = forward() + torch.testing.assert_close(actual, expected) + exec_trace = thunder.last_traces(jforward)[-1] + assert_in_trace(exec_trace, "tensor", ([2],)) + assert_in_trace(exec_trace, "full", ((), 4)) + + def forward(x): + const_t = torch.tensor([2]) + getitem = const_t[0] # 2 + getitem_2 = ( + torch.zeros( + 2, + ) + + 1 + )[ + 0 + ] # 1 + return x + getitem + getitem_2 + + jforward = thunder.jit(forward, transforms=[ConstantFolding()]) + x = torch.randn(3, 3) + actual = jforward(x) + expected = forward(x) + torch.testing.assert_close(actual, expected) + exec_trace = thunder.last_traces(jforward)[-1] + assert_in_trace(exec_trace, "full", ((), 2)) + assert_in_trace(exec_trace, "full", ((), 1.0)) + + def forward(x): + const_t = torch.tensor([2], dtype=torch.float16) + ones_t = torch.ones(1, dtype=torch.float32) + s1 = const_t * 2 # 4 + s2 = const_t / 1 # 2 + s3 = s1 * s2 + 10 # 18 + ones_mul_10 = ones_t * 10 # 10 + return x[0, 0] + s3 + ones_mul_10 + + jforward = thunder.jit(forward, transforms=[ConstantFolding()]) + x = torch.randn(3, 3) + actual = jforward(x) + expected = forward(x) + torch.testing.assert_close(actual, expected) + exec_trace = thunder.last_traces(jforward)[-1] + assert_in_trace(exec_trace, "tensor", ([18.0],)) + assert_in_trace(exec_trace, "tensor", ([10.0],)) + + # Constant folding of Python constants. + def forward(x): + t = torch.tensor(2.0) + return x + (t.item() + (t + 1).item()) + + jforward = thunder.jit(forward, transforms=[ConstantFolding()]) + x = torch.randn(3, 3) + actual = jforward(x) + expected = forward(x) + torch.testing.assert_close(actual, expected) + exec_trace = thunder.last_traces(jforward)[-1] + # exec_trace will look something like this + # def computation(x): + # # x: "cpu f32[3]" + # t5 = torch.add(x, 5.0, alpha=1) # t5: "cpu f32[3]" + # # t5 = ltorch.add(x, 5.0, alpha=1) # t5: "cpu f32[3]" + # # t5 = prims.add(x, 5.0) # t5: "cpu f32[3]" + # return t5 + + # So we check that torch.add has 5.0 in it's arguments. 
+ for bsym in exec_trace.bound_symbols: + if bsym.sym.id == "add": + assert bsym.args[1] == 5.0 + break + else: + raise RuntimeError("Failed to find `add` symbol in trace") diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py index 51f8c62d54..45e2874942 100644 --- a/thunder/torch/__init__.py +++ b/thunder/torch/__init__.py @@ -1852,9 +1852,9 @@ def silu(a: TensorLike, /, inplace: bool = False) -> TensorLike: @torchsymbol(torch.add, is_method=True) def add( - a: NumberLike | TensorLike, b: NumberLike | TensorLike, /, *, alpha: None | Number | TensorLike = None + a: NumberLike | TensorLike, b: NumberLike | TensorLike, /, *, alpha: Number | TensorLike = 1 ) -> Number | TensorLike: - if alpha is not None: + if isinstance(alpha, TensorProxy) or alpha != 1: b = b * alpha return clang.add(a, b) @@ -1866,7 +1866,7 @@ def add_( b: NumberLike | TensorLike, /, *, - alpha: None | Number | TensorLike = None, + alpha: Number | TensorLike = 1, ) -> TensorLike: return prims.copy_(add(a, b, alpha=alpha), a) @@ -2144,15 +2144,15 @@ def remainder_(a, b, /): @torchsymbol(torch.sub, is_method=True) -def sub(a, b, /, *, alpha=None): - if alpha is not None: +def sub(a, b, /, *, alpha: NumberLike | TensorLike = 1): + if isinstance(alpha, TensorProxy) or alpha != 1: b = b * alpha return clang.sub(a, b) @torchsymbol(torch.Tensor.sub_, is_method=True, tags=(prims.OpTags.IN_PLACE,)) -def sub_(a, b, /, *, alpha=None): +def sub_(a, b, /, *, alpha: NumberLike | TensorLike = 1): return prims.copy_(sub(a, b, alpha=alpha), a) diff --git a/thunder/transforms/constant_folding.py b/thunder/transforms/constant_folding.py new file mode 100644 index 0000000000..6f5273c4a1 --- /dev/null +++ b/thunder/transforms/constant_folding.py @@ -0,0 +1,166 @@ +from numbers import Number +from typing import Any +from collections.abc import Callable + +import torch + +import thunder +from thunder.core.trace import from_trace, tracectx +from thunder.core.proxies import variableify, Variable, TensorProxy, NumberProxy, proxy +from thunder.core.symbol import BoundSymbol +from thunder.core.dtypes import to_dtype +from thunder.core.devices import to_device +from thunder.torch import _torch_to_thunder_function_map +from thunder.core.utils import get_symbols_to_last_used_variables + +_thunder_to_torch_function_map = {v: k for k, v in _torch_to_thunder_function_map.items()} + +# Factory functions whose value we know. +TENSOR_FACTORY = ( + thunder.torch.tensor.id, + thunder.torch.ones.id, + thunder.torch.zeros.id, +) + + +def get_python_operator(bsym) -> None | Callable: + from thunder.executors.pythonex import ex as pythonex + + if pythonex.can_execute(bsym): + # TODO - Is there a better way to do the same? + # This seems brittle and tailored towards + # current implementation of pythonex. + impl = pythonex.implmap[bsym.sym.id] + module = impl.symbol.module + op = getattr(module, impl.symbol.id) + return op + return None + + +def compute_with_constant_tensors(bsym, const_values) -> None | Any: + """ + This function is used to compute the concrete output of the computation + represented by BoundSymbol if it's inputs are known to be constant. + + To run the computation, it will use PyTorch eager functions + from _torch_to_thunder_function_map or registered operator from + `pythonex` executor. 
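+
+    Returns the concrete result of the computation, or ``None`` when neither a
+    torch eager function nor a ``pythonex`` operator is found for ``bsym``.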
+    """
+
+    def materialize_args(a):
+        if isinstance(a, (TensorProxy, NumberProxy)) and variableify(a) in const_values:
+            return const_values[variableify(a)]
+        elif isinstance(a, NumberProxy):
+            return a.value
+        return a
+
+    new_args = tuple(map(materialize_args, bsym.args))
+    new_kwargs = {k: materialize_args(v) for k, v in bsym.kwargs.items()}
+
+    # Try to see if the symbol is a torch function
+    torch_fn = _thunder_to_torch_function_map.get(bsym.sym, None)
+    if torch_fn is not None:
+        return torch_fn(*new_args, **new_kwargs)
+
+    # Try to see if the symbol is a Python function
+    python_fn = get_python_operator(bsym)
+    if python_fn is not None:
+        return python_fn(*new_args, **new_kwargs)
+    return None
+
+
+class ConstantFolding(thunder.Transform):
+    def transform_traces_pre_prologue(self, prologue_trc, computation_trc, epilogue_trc, **kwargs):
+        # Create a new trace
+        const_folded_trace = from_trace(computation_trc)
+        const_folded_trace.bound_symbols = computation_trc.bound_symbols
+
+        const_values: dict[Variable, torch.Tensor | Number] = {}
+
+        # Tag output from factory functions as constant value.
+        for bsym in const_folded_trace.bound_symbols:
+            if bsym.sym.id in TENSOR_FACTORY:
+                torch_fn = _thunder_to_torch_function_map[bsym.sym]
+                t = torch_fn(*bsym.args, **bsym.kwargs)
+                const_values[variableify(bsym.output)] = t
+
+        new_bsyms = []
+        symbol_to_last_used_variables = get_symbols_to_last_used_variables(const_folded_trace.bound_symbols, ignore=())
+
+        def is_constant(proxy):
+            if isinstance(proxy, TensorProxy) and variableify(proxy) in const_values:
+                return True
+            elif isinstance(proxy, NumberProxy) and variableify(proxy) in const_values:
+                return True
+            elif isinstance(proxy, NumberProxy) and proxy.is_static_constrained():
+                return True
+            return False
+
+        const_number_swapmap = {}
+        for bsym in const_folded_trace.bound_symbols:
+            # If bsym has constant inputs, try to compute the output.
+            if all(map(is_constant, bsym.flat_proxy_args)) and bsym.sym.id not in TENSOR_FACTORY:
+                if bsym.flat_args == []:  # eg, unpack_trivial
+                    continue
+                new_concrete_output = compute_with_constant_tensors(bsym, const_values)
+                if (
+                    new_concrete_output is not None
+                ):  # Might happen for `python_return` as it won't have mapping in `_thunder_to_torch_map`.
+
+                    # Create a new symbol with same output proxy but which will now represent the computed constant value.
+                    # eg.
+                    # known_tensor = torch.tensor(2)
+                    # t = known_tensor + 1 --> t = torch.tensor(3)
+
+                    # For `ndim==0`, we need to use full as `tensor_from_sequence` expects
+                    # a sequence (and not plain numbers).
+                    if isinstance(new_concrete_output, Number):
+                        const_number_swapmap[variableify(bsym.output)] = new_concrete_output
+                        new_bsym = bsym
+                    elif new_concrete_output.ndim == 0:
+                        assert isinstance(new_concrete_output, torch.Tensor)
+                        new_bsym = BoundSymbol(
+                            thunder.prims.full,
+                            args=(
+                                new_concrete_output.shape,
+                                new_concrete_output.tolist(),
+                            ),
+                            kwargs={
+                                "dtype": to_dtype(new_concrete_output.dtype),
+                                "device": to_device(new_concrete_output.device),
+                            },
+                            output=bsym.output,
+                        )
+                    else:
+                        assert isinstance(new_concrete_output, torch.Tensor)
+                        new_bsym = BoundSymbol(
+                            thunder.prims.tensor_from_sequence,
+                            args=(new_concrete_output.tolist(),),
+                            kwargs={
+                                "dtype": to_dtype(new_concrete_output.dtype),
+                                "device": to_device(new_concrete_output.device),
+                            },
+                            output=bsym.output,
+                        )
+                    new_bsyms.append(new_bsym)
+
+                    # Update const_values (so that the output of this symbol can also feed further constant computation.)
+ const_values[variableify(bsym.output)] = new_concrete_output + + # Clear tensors which won't be used further. + for proxy_v in symbol_to_last_used_variables[bsym]: + const_values.pop(proxy_v, None) + + continue + + # BoundSymbol with non-constant inputs, keep it as-is + new_bsyms.append(bsym) + + # Update all input NumberProxies by constant numbers if possible. + const_folded_trace.bound_symbols = [ + bsym.from_bsym_swap_proxies(const_number_swapmap, skip_output=True) for bsym in new_bsyms + ] + + const_folded_trace.set_provenance("Constant Folding pass") + return prologue_trc, const_folded_trace, epilogue_trc diff --git a/thunder/transforms/cudagraph.py b/thunder/transforms/cudagraph.py index 4a35d23941..e66074d117 100644 --- a/thunder/transforms/cudagraph.py +++ b/thunder/transforms/cudagraph.py @@ -159,7 +159,7 @@ def region_fn(): region_trace.bound_symbols = bsyms region_trace.args = inputs region_trace.kwargs = {} - region_trace.bound_symbols.append(prims.python_return.bind(outputs, output=())) + region_trace.bound_symbols.append(prims.python_return.bind(outputs, output=None)) return region_trace.python_callable() def make_cuda_graph_callable_from_symbols(