diff --git a/.github/workflows/docs-build.yml b/.github/workflows/docs-build.yml index b8d87e466d..4dd1365477 100644 --- a/.github/workflows/docs-build.yml +++ b/.github/workflows/docs-build.yml @@ -15,11 +15,51 @@ defaults: jobs: docs-make: - uses: Lightning-AI/utilities/.github/workflows/check-docs.yml@v0.11.0 - with: - python-version: "3.10" - requirements-file: "requirements/docs.txt" - install-tex: true + if: github.event.pull_request.draft == false + runs-on: ubuntu-22.04 + strategy: + fail-fast: false + matrix: + target: ["html", "doctest", "linkcheck"] + env: + ARTIFACT_DAYS: 0 + PYPI_LOCAL_DIR: "pypi_pkgs/" + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.10" + + - name: Pull sphinx template + run: | + pip install -q "awscli >=1.30.0" + aws s3 sync --no-sign-request s3://sphinx-packages/ ${PYPI_LOCAL_DIR} + pip install lai-sphinx-theme -U -f ${PYPI_LOCAL_DIR} + - name: Install pandoc + timeout-minutes: 5 + run: | + sudo apt-get update --fix-missing + sudo apt-get install -y pandoc + - name: Install package & dependencies + timeout-minutes: 20 + run: pip install . -U -r requirements/docs.txt + + - name: Make ${{ matrix.target }} + working-directory: docs/ + # allow failing link check and doctest if you run with dispatch + continue-on-error: ${{ matrix.target == 'doctest' || matrix.target == 'linkcheck' }} + run: make ${{ matrix.target }} --debug --jobs $(nproc) SPHINXOPTS="-W --keep-going" + + - name: Keep artifact + if: github.event_name == 'pull_request' + run: echo "ARTIFACT_DAYS=7" >> $GITHUB_ENV + - name: Upload built docs + if: ${{ matrix.target == 'html' }} + uses: actions/upload-artifact@v4 + with: + name: docs-html-${{ github.sha }} + path: docs/build/html/ + retention-days: ${{ env.ARTIFACT_DAYS }} deploy-docs: needs: docs-make @@ -28,7 +68,7 @@ jobs: env: GCP_TARGET: "gs://lightning-docs-thunder" steps: - - uses: actions/download-artifact@v3 + - uses: actions/download-artifact@v4 with: name: docs-html-${{ github.sha }} path: docs/build/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f16e1ef98f..b3e1716d00 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,7 +8,7 @@ ci: repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 + rev: v4.5.0 hooks: - id: end-of-file-fixer - id: trailing-whitespace @@ -24,7 +24,7 @@ repos: - id: detect-private-key - repo: https://github.com/asottile/pyupgrade - rev: v3.11.1 + rev: v3.15.2 hooks: - id: pyupgrade args: ["--py310-plus"] @@ -32,14 +32,14 @@ repos: exclude: "examples|thunder/tests/test_interpreter.py|thunder/tests/test_jit_general.py" - repo: https://github.com/codespell-project/codespell - rev: v2.2.5 + rev: v2.2.6 hooks: - id: codespell additional_dependencies: [tomli] #args: ["--write-changes"] # uncomment if you want to get automatic fixing - repo: https://github.com/psf/black - rev: 23.9.1 + rev: 24.3.0 hooks: - id: black name: Black code @@ -61,7 +61,7 @@ repos: - id: sphinx-lint - repo: https://github.com/asottile/yesqa - rev: v1.4.0 + rev: v1.5.0 hooks: - id: yesqa diff --git a/Makefile b/Makefile index 677ca6b1b9..2c27c12f5c 100644 --- a/Makefile +++ b/Makefile @@ -10,7 +10,13 @@ test: clean python -m coverage run --source thunder -m pytest thunder tests -v python -m coverage report -docs: clean +sphinx-theme: + pip install -q awscli + mkdir -p dist/ + aws s3 sync --no-sign-request s3://sphinx-packages/ dist/ + pip install lai-sphinx-theme -f dist/ + +docs: clean sphinx-theme pip install -e . 
--quiet -r requirements/docs.txt -f https://download.pytorch.org/whl/cpu/torch_stable.html cd docs ; python -m sphinx -b html -W --keep-going source build diff --git a/README.md b/README.md index a672abe48e..3e33bf5bd4 100644 --- a/README.md +++ b/README.md @@ -14,8 +14,9 @@ ______________________________________________________________________ Get startedInstallExamples • - Features • - Documentation • + Inside Thunder • + Get involved! • + Documentation

[![license](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/Lightning-AI/lightning-thunder/blob/main/LICENSE) @@ -30,41 +31,58 @@ ______________________________________________________________________ **Thunder makes PyTorch models Lightning fast.** -Thunder is a source-to-source compiler for PyTorch. It makes PyTorch programs faster by combining and using different hardware executors at once (ie: nvFuser, torch.compile, cuDNN, and TransformerEngine FP8). +Thunder is a source-to-source compiler for PyTorch. It makes PyTorch programs faster by combining and using different hardware executors at once (for instance, [nvFuser](https://github.com/NVIDIA/Fuser), [torch.compile](https://pytorch.org/docs/stable/torch.compiler.html), [cuDNN](https://developer.nvidia.com/cudnn), and [TransformerEngine FP8](https://github.com/NVIDIA/TransformerEngine)). -Works on single accelerators and in multi-GPU settings. +It supports both single and multi-GPU configurations. Thunder aims to be usable, understandable, and extensible. -## Performance +  -Thunder can achieve significant speedups over standard PyTorch eager code, through the compounding effects of optimizations and the use of best-in-class executors. Here is an example of the pretraining throughput for Llama 2 7B as implemented in [LitGPT](https://github.com/Lightning-AI/litgpt). +> \[!Note\] +> Lightning Thunder is in alpha. Feel free to get involved, but expect a few bumps along the way. + +  + +## Single-GPU performance + +Thunder can achieve significant speedups over standard non-compiled PyTorch code ("PyTorch eager"), through the compounding effects of optimizations and the use of best-in-class executors. The figure below shows the pretraining throughput for Llama 2 7B as implemented in [LitGPT](https://github.com/Lightning-AI/litgpt).
Thunder
-Thunder achieves a 40% speedup in training throughput compared to eager code on H100 using a combination of executors including nvFuser, torch.compile, cuDNN, and TransformerEngine FP8. +As shown in the plot above, Thunder achieves a 40% speedup in training throughput compared to eager code on H100 using a combination of executors including nvFuser, torch.compile, cuDNN, and TransformerEngine FP8. + +&nbsp; -Thunder supports distributed strategies like DDP and FSDP (ZeRO2 and ZeRO3). Here is the normalized throughput measured for Llama 2 7B (this time without FP8 mixed precision, support for FSDP is underway). +## Multi-GPU performance + +Thunder also supports distributed strategies such as DDP and FSDP for training models on multiple GPUs. The following plot displays the normalized throughput measured for Llama 2 7B without FP8 mixed precision; FP8 support for FSDP is still in progress.
Thunder
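The plots above summarize throughput; as a rough illustration of how these distributed strategies are applied in code, here is a hedged sketch. It assumes the `thunder.distributed.fsdp` transform (the `ddp` transform is used the same way) and a NCCL process group launched via `torchrun`; the exact entry points may differ from the released API, and the toy model simply mirrors the small `Sequential` network used in the FSDP tutorial notebook later in this diff.

```python
# Hedged sketch: shard a toy model with Thunder's FSDP transform, then compile it.
# Assumes `thunder.distributed.fsdp` and environment variables set by `torchrun`.
import torch
import torch.distributed as dist

import thunder
from thunder.distributed import fsdp

dist.init_process_group(backend="nccl")  # RANK/WORLD_SIZE come from torchrun
device = torch.device("cuda", dist.get_rank() % torch.cuda.device_count())

model = torch.nn.Sequential(
    torch.nn.Linear(64, 64, bias=False),
    torch.nn.Tanh(),
    torch.nn.Linear(64, 64, bias=False),
).to(device)

# Shard the parameters across ranks, then let Thunder insert the required
# all-gather / reduce-scatter communication when it compiles the module.
jmodel = thunder.jit(fsdp(model))

x = torch.randn(64, 64, device=device)
jmodel(x).sum().backward()
```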
-**NOTE: Lightning Thunder is alpha.** Feel free to get involved, expect a few bumps along the way. +&nbsp; ## Get started -Try Thunder without installing by using our [Zero to Thunder Tutorial Studio](https://lightning.ai/lightning-ai/studios/zero-to-thunder-tutorial). +The easiest way to get started with Thunder, requiring no extra installations or setups, is by using our [Zero to Thunder Tutorial Studio](https://lightning.ai/lightning-ai/studios/zero-to-thunder-tutorial). + +&nbsp; ## Install Thunder -Install [nvFuser](https://github.com/NVIDIA/Fuser) nightly, and Thunder together +To use Thunder on your local machine, first install [nvFuser](https://github.com/NVIDIA/Fuser) nightly and PyTorch nightly together as follows: ```bash # install nvFuser which installs the matching nightly PyTorch pip install --pre 'nvfuser-cu121[torch]' --extra-index-url https://pypi.nvidia.com +``` +Then, install Thunder as follows: + +``` # install thunder pip install lightning-thunder ``` @@ -73,26 +91,60 @@ pip install lightning-thunder Advanced install options +&nbsp; + ### Install from main +Alternatively, you can install the latest version of Thunder directly from this GitHub repository as follows: + +``` +# 1) Install nvFuser and PyTorch nightly dependencies: +pip install --pre 'nvfuser-cu121[torch]' --extra-index-url https://pypi.nvidia.com +``` + ```bash +# 2) Install Thunder itself pip install git+https://github.com/Lightning-AI/lightning-thunder.git ``` +&nbsp; + ### Install to tinker and contribute -Install this way to tinker with the internals and contribute: +If you are interested in tinkering with and contributing to Thunder, we recommend cloning the Thunder repository and installing it in pip's editable mode: ```bash +git clone https://github.com/Lightning-AI/lightning-thunder.git +cd lightning-thunder pip install -e . ``` +&nbsp; + +### Develop and run tests + +After cloning the lightning-thunder repository and installing it as an editable package as explained above, you can set up your environment for developing Thunder by installing the development requirements: + +```bash +pip install -r requirements/devel.txt +``` + +Now you can run tests: + +```bash +pytest thunder/tests +``` + +Thunder is very thoroughly tested, so expect this to take a while. + +&nbsp; + ## Hello World -Here is a simple example of how Thunder lets you compile and run PyTorch code: +Below is a simple example of how Thunder allows you to compile and run PyTorch code: ```python import torch @@ -120,15 +172,19 @@ print(result) The compiled function `jfoo` takes and returns PyTorch tensors, just like the original function, so modules and functions compiled by Thunder can be used as part of larger PyTorch programs. +&nbsp; + ## Train models Thunder is in its early stages and should not be used for production runs yet. -However, it can already deliver outstanding performance on LLM model supported by [LitGPT](https://github.com/Lightning-AI/lit-gpt), such as Mistral, Llama 2, Gemma, Falcon, and others. +However, it can already deliver outstanding performance for pretraining and finetuning LLMs supported by [LitGPT](https://github.com/Lightning-AI/lit-gpt), such as Mistral, Llama 2, Gemma, Falcon, and others. Check out [the LitGPT integration](https://github.com/Lightning-AI/litgpt/tree/main/extensions/thunder) to learn about running LitGPT and Thunder together.
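As a bridge between the Hello World example and full LitGPT runs, the hedged sketch below shows what a Thunder-compiled module looks like inside an ordinary PyTorch training step; the tiny model and random data are placeholders rather than anything from LitGPT, and the pattern simply follows the statement above that compiled modules interoperate with autograd and the rest of PyTorch.

```python
# Minimal sketch: compile an nn.Module with thunder.jit and train it with a
# standard optimizer. Model and data are toy placeholders.
import torch
import thunder

model = torch.nn.Sequential(
    torch.nn.Linear(512, 1024),
    torch.nn.GELU(),
    torch.nn.Linear(1024, 512),
)
jmodel = thunder.jit(model)  # behaves like `model`; parameters stay shared

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
x = torch.randn(8, 512)
target = torch.randn(8, 512)

loss = torch.nn.functional.mse_loss(jmodel(x), target)
loss.backward()  # autograd flows through the Thunder-compiled module
optimizer.step()
optimizer.zero_grad()
```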
-## Features +  + +## Inside Thunder: A brief look at the core features Given a Python callable or PyTorch module, Thunder can generate an optimized program that: @@ -140,13 +196,13 @@ Given a Python callable or PyTorch module, Thunder can generate an optimized pro To do so, Thunder ships with: - A JIT for acquiring Python programs targeting PyTorch and custom operations -- A multi-level IR to represent operations as a trace of a reduced op-set -- An extensible set of transformations on the trace, such as `grad`, fusions, distributed (like `ddp`, `fsdp`), functional (like `vmap`, `vjp`, `jvp`) +- A multi-level intermediate representation (IR) to represent operations as a trace of a reduced operation set +- An extensible set of transformations on the trace of a computational graph, such as `grad`, fusions, distributed (like `ddp`, `fsdp`), functional (like `vmap`, `vjp`, `jvp`) - A way to dispatch operations to an extensible collection of executors Thunder is written entirely in Python. Even its trace is represented as valid Python at all stages of transformation. This allows unprecedented levels of introspection and extensibility. -Thunder doesn't generate code for accelerators directly. It acquires and transforms user programs so that it's possible to optimally select or generate device code using fast executors like: +Thunder doesn't generate code for accelerators, such as GPUs, directly. It acquires and transforms user programs so that it's possible to optimally select or generate device code using fast executors like: - [torch.compile](https://pytorch.org/get-started/pytorch-2.0/) - [nvFuser](https://github.com/NVIDIA/Fuser) @@ -159,6 +215,8 @@ Thunder doesn't generate code for accelerators directly. It acquires and transfo Modules and functions compiled with Thunder fully interoperate with vanilla PyTorch and support PyTorch's autograd. Also, Thunder works alongside torch.compile to leverage its state-of-the-art optimizations. +  + ## Documentation Docs are currently not hosted publicly. However you can build them locally really quickly: @@ -169,27 +227,15 @@ make docs and point your browser to the generated docs at `docs/build/index.html`. -## Develop and run tests - -You can set up your environment for developing Thunder by installing the development requirements: - -```bash -pip install -r requirements/devel.txt -``` +  -Install Thunder as an editable package (optional): - -```bash -pip install -e . -``` +## Get involved! -Now you run tests: +We appreciate your feedback and contributions. If you have feature requests, questions, or want to contribute code or config files, please don't hesitate to use the [GitHub Issue](https://github.com/Lightning-AI/lightning-thunder/issues) tracker. -```bash -pytest thunder/tests -``` +We welcome all individual contributors, regardless of their level of experience or hardware. Your contributions are valuable, and we are excited to see what you can accomplish in this collaborative and supportive environment. -Thunder is very thoroughly tested, so expect this to take a while. +  ## License diff --git a/docs/source/conf.py b/docs/source/conf.py index 598dc42053..6c410f8e05 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -11,15 +11,13 @@ # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. 
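To make the trace-centric design described in the "Inside Thunder" section above more concrete, here is a hedged sketch that compiles a small function and prints its final execution trace; it assumes the `thunder.last_traces` helper that Thunder's tutorials use for introspection.

```python
# Hedged sketch: inspect the plain-Python traces Thunder records for a function.
import torch
import thunder


def foo(a, b):
    return torch.tanh(a) @ b


jfoo = thunder.jit(foo)
a = torch.randn(4, 4)
b = torch.randn(4, 4)
jfoo(a, b)  # run once so traces are recorded

# The last entry is the trace that actually executes, after transformations
# such as fusion and executor dispatch have been applied.
print(thunder.last_traces(jfoo)[-1])
```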
-import glob import inspect import os -import re import shutil import sys from importlib.util import module_from_spec, spec_from_file_location -import pt_lightning_sphinx_theme +import lai_sphinx_theme _PATH_HERE = os.path.abspath(os.path.dirname(__file__)) _PATH_ROOT = os.path.realpath(os.path.join(_PATH_HERE, "..", "..")) @@ -99,6 +97,7 @@ def _transform_changelog(path_in: str, path_out: str) -> None: "sphinx_copybutton", "sphinx_paramlinks", "sphinx_togglebutton", + "lai_sphinx_theme.extensions.lightning", ] # Add any paths that contain templates here, relative to this directory. @@ -152,8 +151,8 @@ def _transform_changelog(path_in: str, path_out: str) -> None: # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # -html_theme = "pt_lightning_sphinx_theme" -html_theme_path = [pt_lightning_sphinx_theme.get_html_theme_path()] +html_theme = "lai_sphinx_theme" +html_theme_path = [lai_sphinx_theme.get_html_theme_path()] # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the diff --git a/notebooks/dev_tutorials/fsdp_tutorial.ipynb b/notebooks/dev_tutorials/fsdp_tutorial.ipynb index b8cef2a2ff..41c4dee406 100644 --- a/notebooks/dev_tutorials/fsdp_tutorial.ipynb +++ b/notebooks/dev_tutorials/fsdp_tutorial.ipynb @@ -45,7 +45,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -58,7 +58,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -85,7 +85,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -102,165 +102,9 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
# Constructed by Dead Code Elimination (took 0 milliseconds)\n",
-       "import torch\n",
-       "import torch.nn.functional\n",
-       "from thunder.executors.torchex import no_autocast\n",
-       "\n",
-       "@torch.no_grad()\n",
-       "@no_autocast()\n",
-       "def augmented_forward_fn(*args):\n",
-       "  # args: "Collection" \n",
-       "  t0, \\\n",
-       "  t1, \\\n",
-       "  t2, \\\n",
-       "  = args\n",
-       "  t3 = torch.nn.functional.linear(t0, t1, None)  # t3: "cuda:0 f32[64, 64]"\n",
-       "    # t3 = ltorch.linear(t0, t1, None)  # t3: "cuda:0 f32[64, 64]"\n",
-       "      # t3 = prims.linear(t0, t1, None)  # t3: "cuda:0 f32[64, 64]"\n",
-       "  [t4] = nvFusion0(t3)\n",
-       "    # t4 = prims.tanh(t3)  # t4: "cuda:0 f32[64, 64]"\n",
-       "  t5 = torch.nn.functional.linear(t4, t2, None)  # t5: "cuda:0 f32[64, 64]"\n",
-       "    # t5 = ltorch.linear(t4, t2, None)  # t5: "cuda:0 f32[64, 64]"\n",
-       "      # t5 = prims.linear(t4, t2, None)  # t5: "cuda:0 f32[64, 64]"\n",
-       "  return {'output': t5, 'flat_args': [t0, t1, t2], 'flat_output': (t5,)}, ((t0, t2, t4), ())\n",
-       "
\n" - ], - "text/latex": [ - "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n", - "\\PY{c+c1}{\\PYZsh{} Constructed by Dead Code Elimination (took 0 milliseconds)}\n", - "\\PY{k+kn}{import} \\PY{n+nn}{torch}\n", - "\\PY{k+kn}{import} \\PY{n+nn}{torch}\\PY{n+nn}{.}\\PY{n+nn}{nn}\\PY{n+nn}{.}\\PY{n+nn}{functional}\n", - "\\PY{k+kn}{from} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{executors}\\PY{n+nn}{.}\\PY{n+nn}{torchex} \\PY{k+kn}{import} \\PY{n}{no\\PYZus{}autocast}\n", - "\n", - "\\PY{n+nd}{@torch}\\PY{o}{.}\\PY{n}{no\\PYZus{}grad}\\PY{p}{(}\\PY{p}{)}\n", - "\\PY{n+nd}{@no\\PYZus{}autocast}\\PY{p}{(}\\PY{p}{)}\n", - "\\PY{k}{def} \\PY{n+nf}{augmented\\PYZus{}forward\\PYZus{}fn}\\PY{p}{(}\\PY{o}{*}\\PY{n}{args}\\PY{p}{)}\\PY{p}{:}\n", - " \\PY{c+c1}{\\PYZsh{} args: \\PYZdq{}Collection\\PYZdq{} }\n", - " \\PY{n}{t0}\\PY{p}{,} \\PYZbs{}\n", - " \\PY{n}{t1}\\PY{p}{,} \\PYZbs{}\n", - " \\PY{n}{t2}\\PY{p}{,} \\PYZbs{}\n", - " \\PY{o}{=} \\PY{n}{args}\n", - " \\PY{n}{t3} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{nn}\\PY{o}{.}\\PY{n}{functional}\\PY{o}{.}\\PY{n}{linear}\\PY{p}{(}\\PY{n}{t0}\\PY{p}{,} \\PY{n}{t1}\\PY{p}{,} \\PY{k+kc}{None}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t3: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t3 = ltorch.linear(t0, t1, None) \\PYZsh{} t3: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t3 = prims.linear(t0, t1, None) \\PYZsh{} t3: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{p}{[}\\PY{n}{t4}\\PY{p}{]} \\PY{o}{=} \\PY{n}{nvFusion0}\\PY{p}{(}\\PY{n}{t3}\\PY{p}{)}\n", - " \\PY{c+c1}{\\PYZsh{} t4 = prims.tanh(t3) \\PYZsh{} t4: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{n}{t5} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{nn}\\PY{o}{.}\\PY{n}{functional}\\PY{o}{.}\\PY{n}{linear}\\PY{p}{(}\\PY{n}{t4}\\PY{p}{,} \\PY{n}{t2}\\PY{p}{,} \\PY{k+kc}{None}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t5: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t5 = ltorch.linear(t4, t2, None) \\PYZsh{} t5: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t5 = prims.linear(t4, t2, None) \\PYZsh{} t5: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{k}{return} \\PY{p}{\\PYZob{}}\\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{output}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{n}{t5}\\PY{p}{,} \\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{flat\\PYZus{}args}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{p}{[}\\PY{n}{t0}\\PY{p}{,} \\PY{n}{t1}\\PY{p}{,} \\PY{n}{t2}\\PY{p}{]}\\PY{p}{,} \\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{flat\\PYZus{}output}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{p}{(}\\PY{n}{t5}\\PY{p}{,}\\PY{p}{)}\\PY{p}{\\PYZcb{}}\\PY{p}{,} \\PY{p}{(}\\PY{p}{(}\\PY{n}{t0}\\PY{p}{,} \\PY{n}{t2}\\PY{p}{,} \\PY{n}{t4}\\PY{p}{)}\\PY{p}{,} \\PY{p}{(}\\PY{p}{)}\\PY{p}{)}\n", - "\\end{Verbatim}\n" - ], - "text/plain": [ - "# Constructed by Dead Code Elimination (took 0 milliseconds)\n", - "import torch\n", - "import torch.nn.functional\n", - "from thunder.executors.torchex import no_autocast\n", - "\n", - "@torch.no_grad()\n", - "@no_autocast()\n", - "def augmented_forward_fn(*args):\n", - " # args: \"Collection\" \n", - " t0, \\\n", - " t1, \\\n", - " t2, \\\n", - " = args\n", - " t3 = torch.nn.functional.linear(t0, t1, None) # t3: \"cuda:0 f32[64, 64]\"\n", - " # t3 = ltorch.linear(t0, t1, None) # t3: \"cuda:0 f32[64, 64]\"\n", - " # t3 = prims.linear(t0, t1, None) # t3: \"cuda:0 f32[64, 64]\"\n", - " [t4] = nvFusion0(t3)\n", - " # t4 = prims.tanh(t3) # t4: \"cuda:0 f32[64, 64]\"\n", - " t5 = torch.nn.functional.linear(t4, t2, None) # t5: \"cuda:0 f32[64, 64]\"\n", 
- " # t5 = ltorch.linear(t4, t2, None) # t5: \"cuda:0 f32[64, 64]\"\n", - " # t5 = prims.linear(t4, t2, None) # t5: \"cuda:0 f32[64, 64]\"\n", - " return {'output': t5, 'flat_args': [t0, t1, t2], 'flat_output': (t5,)}, ((t0, t2, t4), ())" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "wrap_as_highlighted_code(computation_trace)" ] @@ -276,11 +120,11 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "# FSDP Config \n", + "# FSDP Config\n", "# Usually these values are set in the environment by `torchrun` but for this example\n", "# we will set them ourselves\n", "world_size = 2 # We have two processes.\n", @@ -298,7 +142,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -326,24 +170,9 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Sequential(\n", - " (0): Linear(in_features=64, out_features=64, bias=False)\n", - " (1): Tanh()\n", - " (2): Linear(in_features=64, out_features=64, bias=False)\n", - ")" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "# Verify our model looks as expected\n", "model" @@ -351,7 +180,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -376,7 +205,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -387,195 +216,9 @@ "torch.distributed.distributed_c10d.GroupMember.WORLD = process_group" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "\n", - "Because we are trying to play tricks with the traces and skip the part that inserts the synchronization automatically but also does the translation from PyTorch to thunder, we need to drop one layer of the trace to apply this manually.\n", - "(This is really hacky, don't try it at home!)" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
# Constructed by Dead Code Elimination (took 0 milliseconds)\n",
-       "import thunder\n",
-       "import thunder.core.prims as prims\n",
-       "import thunder.torch as ltorch\n",
-       "import torch\n",
-       "import torch.nn.functional\n",
-       "from thunder.executors.torchex import no_autocast\n",
-       "\n",
-       "@torch.no_grad()\n",
-       "@no_autocast()\n",
-       "def augmented_forward_fn(*args):\n",
-       "  # args: "Collection" \n",
-       "  t0, \\\n",
-       "  t1, \\\n",
-       "  t2, \\\n",
-       "  = args\n",
-       "  t3 = ltorch.linear(t0, t1, None)  # t3: "cuda:0 f32[64, 64]"\n",
-       "    # t3 = ltorch.linear(t0, t1, None)  # t3: "cuda:0 f32[64, 64]"\n",
-       "      # t3 = prims.linear(t0, t1, None)  # t3: "cuda:0 f32[64, 64]"\n",
-       "  t4 = prims.tanh(t3)  # t4: "cuda:0 f32[64, 64]"\n",
-       "  t5 = ltorch.linear(t4, t2, None)  # t5: "cuda:0 f32[64, 64]"\n",
-       "    # t5 = ltorch.linear(t4, t2, None)  # t5: "cuda:0 f32[64, 64]"\n",
-       "      # t5 = prims.linear(t4, t2, None)  # t5: "cuda:0 f32[64, 64]"\n",
-       "  return {'output': t5, 'flat_args': [t0, t1, t2], 'flat_output': (t5,)}, ((t0, t2, t4), ())\n",
-       "
\n" - ], - "text/latex": [ - "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n", - "\\PY{c+c1}{\\PYZsh{} Constructed by Dead Code Elimination (took 0 milliseconds)}\n", - "\\PY{k+kn}{import} \\PY{n+nn}{thunder}\n", - "\\PY{k+kn}{import} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{core}\\PY{n+nn}{.}\\PY{n+nn}{prims} \\PY{k}{as} \\PY{n+nn}{prims}\n", - "\\PY{k+kn}{import} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{torch} \\PY{k}{as} \\PY{n+nn}{ltorch}\n", - "\\PY{k+kn}{import} \\PY{n+nn}{torch}\n", - "\\PY{k+kn}{import} \\PY{n+nn}{torch}\\PY{n+nn}{.}\\PY{n+nn}{nn}\\PY{n+nn}{.}\\PY{n+nn}{functional}\n", - "\\PY{k+kn}{from} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{executors}\\PY{n+nn}{.}\\PY{n+nn}{torchex} \\PY{k+kn}{import} \\PY{n}{no\\PYZus{}autocast}\n", - "\n", - "\\PY{n+nd}{@torch}\\PY{o}{.}\\PY{n}{no\\PYZus{}grad}\\PY{p}{(}\\PY{p}{)}\n", - "\\PY{n+nd}{@no\\PYZus{}autocast}\\PY{p}{(}\\PY{p}{)}\n", - "\\PY{k}{def} \\PY{n+nf}{augmented\\PYZus{}forward\\PYZus{}fn}\\PY{p}{(}\\PY{o}{*}\\PY{n}{args}\\PY{p}{)}\\PY{p}{:}\n", - " \\PY{c+c1}{\\PYZsh{} args: \\PYZdq{}Collection\\PYZdq{} }\n", - " \\PY{n}{t0}\\PY{p}{,} \\PYZbs{}\n", - " \\PY{n}{t1}\\PY{p}{,} \\PYZbs{}\n", - " \\PY{n}{t2}\\PY{p}{,} \\PYZbs{}\n", - " \\PY{o}{=} \\PY{n}{args}\n", - " \\PY{n}{t3} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{linear}\\PY{p}{(}\\PY{n}{t0}\\PY{p}{,} \\PY{n}{t1}\\PY{p}{,} \\PY{k+kc}{None}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t3: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t3 = ltorch.linear(t0, t1, None) \\PYZsh{} t3: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t3 = prims.linear(t0, t1, None) \\PYZsh{} t3: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{n}{t4} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{tanh}\\PY{p}{(}\\PY{n}{t3}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t4: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{n}{t5} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{linear}\\PY{p}{(}\\PY{n}{t4}\\PY{p}{,} \\PY{n}{t2}\\PY{p}{,} \\PY{k+kc}{None}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t5: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t5 = ltorch.linear(t4, t2, None) \\PYZsh{} t5: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t5 = prims.linear(t4, t2, None) \\PYZsh{} t5: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{k}{return} \\PY{p}{\\PYZob{}}\\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{output}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{n}{t5}\\PY{p}{,} \\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{flat\\PYZus{}args}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{p}{[}\\PY{n}{t0}\\PY{p}{,} \\PY{n}{t1}\\PY{p}{,} \\PY{n}{t2}\\PY{p}{]}\\PY{p}{,} \\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{flat\\PYZus{}output}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{p}{(}\\PY{n}{t5}\\PY{p}{,}\\PY{p}{)}\\PY{p}{\\PYZcb{}}\\PY{p}{,} \\PY{p}{(}\\PY{p}{(}\\PY{n}{t0}\\PY{p}{,} \\PY{n}{t2}\\PY{p}{,} \\PY{n}{t4}\\PY{p}{)}\\PY{p}{,} \\PY{p}{(}\\PY{p}{)}\\PY{p}{)}\n", - "\\end{Verbatim}\n" - ], - "text/plain": [ - "# Constructed by Dead Code Elimination (took 0 milliseconds)\n", - "import thunder\n", - "import thunder.core.prims as prims\n", - "import thunder.torch as ltorch\n", - "import torch\n", - "import torch.nn.functional\n", - "from thunder.executors.torchex import no_autocast\n", - "\n", - "@torch.no_grad()\n", - "@no_autocast()\n", - "def augmented_forward_fn(*args):\n", - " # args: \"Collection\" \n", - " t0, \\\n", - " t1, \\\n", - " t2, \\\n", - " = args\n", - " t3 = ltorch.linear(t0, t1, None) # t3: \"cuda:0 f32[64, 64]\"\n", - " # t3 = ltorch.linear(t0, t1, None) # t3: \"cuda:0 f32[64, 64]\"\n", - 
" # t3 = prims.linear(t0, t1, None) # t3: \"cuda:0 f32[64, 64]\"\n", - " t4 = prims.tanh(t3) # t4: \"cuda:0 f32[64, 64]\"\n", - " t5 = ltorch.linear(t4, t2, None) # t5: \"cuda:0 f32[64, 64]\"\n", - " # t5 = ltorch.linear(t4, t2, None) # t5: \"cuda:0 f32[64, 64]\"\n", - " # t5 = prims.linear(t4, t2, None) # t5: \"cuda:0 f32[64, 64]\"\n", - " return {'output': t5, 'flat_args': [t0, t1, t2], 'flat_output': (t5,)}, ((t0, t2, t4), ())" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "### DON'T TRY THIS AT HOME\n", - "computation_trace.bound_symbols[2].sym = cache_rec.computation_traces[0].bound_symbols[2].subsymbols[0].sym\n", - "if cache_rec.computation_traces[0].bound_symbols[3].subsymbols:\n", - " computation_trace.bound_symbols[3] = cache_rec.computation_traces[0].bound_symbols[3].subsymbols[0]\n", - "computation_trace.bound_symbols[4].sym = cache_rec.computation_traces[0].bound_symbols[4].subsymbols[0].sym\n", - "\n", - "wrap_as_highlighted_code(computation_trace)" - ] - }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -611,171 +254,9 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
# Constructed by Dead Code Elimination (took 0 milliseconds)\n",
-       "import thunder\n",
-       "import thunder.core.prims as prims\n",
-       "import thunder.distributed.prims\n",
-       "import thunder.torch as ltorch\n",
-       "import torch\n",
-       "from thunder.executors.torchex import no_autocast\n",
-       "\n",
-       "@torch.no_grad()\n",
-       "@no_autocast()\n",
-       "def model_with_syncs(x, *params):\n",
-       "  # x: "cuda:0 f32[64, 64]" \n",
-       "  # params: "Collection" \n",
-       "  t0, \\\n",
-       "  t1, \\\n",
-       "  = params\n",
-       "  t2 = thunder.distributed.prims.synchronize(t0, _torch_distributed_distributed_c10d_ProcessGroup_0)  # t2: "cuda:0 f32[64, 64]"\n",
-       "  t3 = thunder.distributed.prims.synchronize(t1, _torch_distributed_distributed_c10d_ProcessGroup_0)  # t3: "cuda:0 f32[64, 64]"\n",
-       "  t4 = ltorch.linear(x, t2, None)  # t4: "cuda:0 f32[64, 64]"\n",
-       "    # t4 = prims.linear(x, t2, None)  # t4: "cuda:0 f32[64, 64]"\n",
-       "  t5 = prims.tanh(t4)  # t5: "cuda:0 f32[64, 64]"\n",
-       "  t6 = ltorch.linear(t5, t3, None)  # t6: "cuda:0 f32[64, 64]"\n",
-       "    # t6 = prims.linear(t5, t3, None)  # t6: "cuda:0 f32[64, 64]"\n",
-       "  return ({'output': t6, 'flat_args': [x, t2, t3], 'flat_output': (t6,)}, ((x, t3, t5), ()))\n",
-       "
\n" - ], - "text/latex": [ - "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n", - "\\PY{c+c1}{\\PYZsh{} Constructed by Dead Code Elimination (took 0 milliseconds)}\n", - "\\PY{k+kn}{import} \\PY{n+nn}{thunder}\n", - "\\PY{k+kn}{import} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{core}\\PY{n+nn}{.}\\PY{n+nn}{prims} \\PY{k}{as} \\PY{n+nn}{prims}\n", - "\\PY{k+kn}{import} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{distributed}\\PY{n+nn}{.}\\PY{n+nn}{prims}\n", - "\\PY{k+kn}{import} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{torch} \\PY{k}{as} \\PY{n+nn}{ltorch}\n", - "\\PY{k+kn}{import} \\PY{n+nn}{torch}\n", - "\\PY{k+kn}{from} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{executors}\\PY{n+nn}{.}\\PY{n+nn}{torchex} \\PY{k+kn}{import} \\PY{n}{no\\PYZus{}autocast}\n", - "\n", - "\\PY{n+nd}{@torch}\\PY{o}{.}\\PY{n}{no\\PYZus{}grad}\\PY{p}{(}\\PY{p}{)}\n", - "\\PY{n+nd}{@no\\PYZus{}autocast}\\PY{p}{(}\\PY{p}{)}\n", - "\\PY{k}{def} \\PY{n+nf}{model\\PYZus{}with\\PYZus{}syncs}\\PY{p}{(}\\PY{n}{x}\\PY{p}{,} \\PY{o}{*}\\PY{n}{params}\\PY{p}{)}\\PY{p}{:}\n", - " \\PY{c+c1}{\\PYZsh{} x: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{} }\n", - " \\PY{c+c1}{\\PYZsh{} params: \\PYZdq{}Collection\\PYZdq{} }\n", - " \\PY{n}{t0}\\PY{p}{,} \\PYZbs{}\n", - " \\PY{n}{t1}\\PY{p}{,} \\PYZbs{}\n", - " \\PY{o}{=} \\PY{n}{params}\n", - " \\PY{n}{t2} \\PY{o}{=} \\PY{n}{thunder}\\PY{o}{.}\\PY{n}{distributed}\\PY{o}{.}\\PY{n}{prims}\\PY{o}{.}\\PY{n}{synchronize}\\PY{p}{(}\\PY{n}{t0}\\PY{p}{,} \\PY{n}{\\PYZus{}torch\\PYZus{}distributed\\PYZus{}distributed\\PYZus{}c10d\\PYZus{}ProcessGroup\\PYZus{}0}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t2: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{n}{t3} \\PY{o}{=} \\PY{n}{thunder}\\PY{o}{.}\\PY{n}{distributed}\\PY{o}{.}\\PY{n}{prims}\\PY{o}{.}\\PY{n}{synchronize}\\PY{p}{(}\\PY{n}{t1}\\PY{p}{,} \\PY{n}{\\PYZus{}torch\\PYZus{}distributed\\PYZus{}distributed\\PYZus{}c10d\\PYZus{}ProcessGroup\\PYZus{}0}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t3: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{n}{t4} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{linear}\\PY{p}{(}\\PY{n}{x}\\PY{p}{,} \\PY{n}{t2}\\PY{p}{,} \\PY{k+kc}{None}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t4: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t4 = prims.linear(x, t2, None) \\PYZsh{} t4: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{n}{t5} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{tanh}\\PY{p}{(}\\PY{n}{t4}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t5: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{n}{t6} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{linear}\\PY{p}{(}\\PY{n}{t5}\\PY{p}{,} \\PY{n}{t3}\\PY{p}{,} \\PY{k+kc}{None}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t6: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t6 = prims.linear(t5, t3, None) \\PYZsh{} t6: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{k}{return} \\PY{p}{(}\\PY{p}{\\PYZob{}}\\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{output}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{n}{t6}\\PY{p}{,} \\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{flat\\PYZus{}args}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{p}{[}\\PY{n}{x}\\PY{p}{,} \\PY{n}{t2}\\PY{p}{,} \\PY{n}{t3}\\PY{p}{]}\\PY{p}{,} \\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{flat\\PYZus{}output}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{p}{(}\\PY{n}{t6}\\PY{p}{,}\\PY{p}{)}\\PY{p}{\\PYZcb{}}\\PY{p}{,} \\PY{p}{(}\\PY{p}{(}\\PY{n}{x}\\PY{p}{,} \\PY{n}{t3}\\PY{p}{,} \\PY{n}{t5}\\PY{p}{)}\\PY{p}{,} \\PY{p}{(}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n", - "\\end{Verbatim}\n" - ], - "text/plain": [ - "# Constructed by Dead Code Elimination (took 0 milliseconds)\n", - 
"import thunder\n", - "import thunder.core.prims as prims\n", - "import thunder.distributed.prims\n", - "import thunder.torch as ltorch\n", - "import torch\n", - "from thunder.executors.torchex import no_autocast\n", - "\n", - "@torch.no_grad()\n", - "@no_autocast()\n", - "def model_with_syncs(x, *params):\n", - " # x: \"cuda:0 f32[64, 64]\" \n", - " # params: \"Collection\" \n", - " t0, \\\n", - " t1, \\\n", - " = params\n", - " t2 = thunder.distributed.prims.synchronize(t0, _torch_distributed_distributed_c10d_ProcessGroup_0) # t2: \"cuda:0 f32[64, 64]\"\n", - " t3 = thunder.distributed.prims.synchronize(t1, _torch_distributed_distributed_c10d_ProcessGroup_0) # t3: \"cuda:0 f32[64, 64]\"\n", - " t4 = ltorch.linear(x, t2, None) # t4: \"cuda:0 f32[64, 64]\"\n", - " # t4 = prims.linear(x, t2, None) # t4: \"cuda:0 f32[64, 64]\"\n", - " t5 = prims.tanh(t4) # t5: \"cuda:0 f32[64, 64]\"\n", - " t6 = ltorch.linear(t5, t3, None) # t6: \"cuda:0 f32[64, 64]\"\n", - " # t6 = prims.linear(t5, t3, None) # t6: \"cuda:0 f32[64, 64]\"\n", - " return ({'output': t6, 'flat_args': [x, t2, t3], 'flat_output': (t6,)}, ((x, t3, t5), ()))" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "trace = thunder.trace()(model_with_syncs, x, *model.parameters())\n", "\n", @@ -795,339 +276,9 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
# Constructed by Dead Code Elimination (took 1 milliseconds)\n",
-       "import thunder\n",
-       "import thunder.core.devices as devices\n",
-       "import thunder.core.dtypes as dtypes\n",
-       "import thunder.core.prims as prims\n",
-       "import thunder.distributed.prims\n",
-       "import thunder.torch as ltorch\n",
-       "import torch\n",
-       "from thunder.executors.torchex import no_autocast\n",
-       "\n",
-       "@torch.no_grad()\n",
-       "@no_autocast()\n",
-       "def _value_and_grad(*args):\n",
-       "  # args: "Collection" \n",
-       "  t0, \\\n",
-       "  t1, \\\n",
-       "  t2, \\\n",
-       "  = args\n",
-       "  t3 = prims.full((64, 64), 1, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # t3: "cuda:0 f32[64, 64]"\n",
-       "  t4 = prims.full((64, 64), 1, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # t4: "cuda:0 f32[64, 64]"\n",
-       "  t5 = prims.full((64, 64), 1, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # t5: "cuda:0 f32[64, 64]"\n",
-       "  t6 = prims.full((64, 64), 1, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # t6: "cuda:0 f32[64, 64]"\n",
-       "  t7 = prims.full((64, 64), 1, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # t7: "cuda:0 f32[64, 64]"\n",
-       "  t8 = prims.full((64, 64), 1, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # t8: "cuda:0 f32[64, 64]"\n",
-       "  t9 = prims.full((64, 64), 1, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # t9: "cuda:0 f32[64, 64]"\n",
-       "  t10 = prims.full((64, 64), 1, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # t10: "cuda:0 f32[64, 64]"\n",
-       "  p11 = thunder.distributed.prims.all_gather(t1, _torch_distributed_distributed_c10d_ProcessGroup_0, True)  # p11: "FUTURE cuda:0 f32[64, 64]"\n",
-       "  t12 = thunder.distributed.prims.wait(p11)  # t12: "cuda:0 f32[64, 64]"\n",
-       "  p13 = thunder.distributed.prims.all_gather(t2, _torch_distributed_distributed_c10d_ProcessGroup_0, True)  # p13: "FUTURE cuda:0 f32[64, 64]"\n",
-       "  t14 = thunder.distributed.prims.wait(p13)  # t14: "cuda:0 f32[64, 64]"\n",
-       "  t15 = prims.linear(t0, t12, None)  # t15: "cuda:0 f32[64, 64]"\n",
-       "  t16 = prims.tanh(t15)  # t16: "cuda:0 f32[64, 64]"\n",
-       "  t17 = prims.linear(t16, t14, None)  # t17: "cuda:0 f32[64, 64]"\n",
-       "  t18 = prims.add(t6, t7)  # t18: "cuda:0 f32[64, 64]"\n",
-       "  t19 = prims.add(t3, t8)  # t19: "cuda:0 f32[64, 64]"\n",
-       "  t20 = prims.add(t5, t9)  # t20: "cuda:0 f32[64, 64]"\n",
-       "  t21 = ltorch.reshape(t18, -1, 64)  # t21: "cuda:0 f32[64, 64]"\n",
-       "    # t21 = prims.reshape(t18, (64, 64))  # t21: "cuda:0 f32[64, 64]"\n",
-       "  t22 = ltorch.matmul(t21, t14)  # t22: "cuda:0 f32[64, 64]"\n",
-       "    # t22 = prims.matmul(t21, t14)  # t22: "cuda:0 f32[64, 64]"\n",
-       "  t23 = ltorch.reshape(t18, -1, 64)  # t23: "cuda:0 f32[64, 64]"\n",
-       "    # t23 = prims.reshape(t18, (64, 64))  # t23: "cuda:0 f32[64, 64]"\n",
-       "  t24 = prims.transpose(t23, (1, 0))  # t24: "cuda:0 f32[64, 64]"\n",
-       "  t25 = ltorch.reshape(t16, -1, 64)  # t25: "cuda:0 f32[64, 64]"\n",
-       "    # t25 = prims.reshape(t16, (64, 64))  # t25: "cuda:0 f32[64, 64]"\n",
-       "  t26 = ltorch.matmul(t24, t25)  # t26: "cuda:0 f32[64, 64]"\n",
-       "    # t26 = prims.matmul(t24, t25)  # t26: "cuda:0 f32[64, 64]"\n",
-       "  t27 = prims.add(t10, t22)  # t27: "cuda:0 f32[64, 64]"\n",
-       "  t28 = prims.add(t20, t26)  # t28: "cuda:0 f32[64, 64]"\n",
-       "  t29 = ltorch.mul(t16, t16)  # t29: "cuda:0 f32[64, 64]"\n",
-       "    # t29 = prims.mul(t16, t16)  # t29: "cuda:0 f32[64, 64]"\n",
-       "  t30 = ltorch.sub(1, t29, alpha=None)  # t30: "cuda:0 f32[64, 64]"\n",
-       "    # _ = prims.convert_element_type(1, float)\n",
-       "    # t30 = prims.sub(1.0, t29)  # t30: "cuda:0 f32[64, 64]"\n",
-       "  t31 = ltorch.mul(t27, t30)  # t31: "cuda:0 f32[64, 64]"\n",
-       "    # t31 = prims.mul(t27, t30)  # t31: "cuda:0 f32[64, 64]"\n",
-       "  t32 = ltorch.reshape(t31, -1, 64)  # t32: "cuda:0 f32[64, 64]"\n",
-       "    # t32 = prims.reshape(t31, (64, 64))  # t32: "cuda:0 f32[64, 64]"\n",
-       "  t33 = ltorch.matmul(t32, t12)  # t33: "cuda:0 f32[64, 64]"\n",
-       "    # t33 = prims.matmul(t32, t12)  # t33: "cuda:0 f32[64, 64]"\n",
-       "  t34 = ltorch.reshape(t31, -1, 64)  # t34: "cuda:0 f32[64, 64]"\n",
-       "    # t34 = prims.reshape(t31, (64, 64))  # t34: "cuda:0 f32[64, 64]"\n",
-       "  t35 = prims.transpose(t34, (1, 0))  # t35: "cuda:0 f32[64, 64]"\n",
-       "  t36 = ltorch.reshape(t0, -1, 64)  # t36: "cuda:0 f32[64, 64]"\n",
-       "    # t36 = prims.reshape(t0, (64, 64))  # t36: "cuda:0 f32[64, 64]"\n",
-       "  t37 = ltorch.matmul(t35, t36)  # t37: "cuda:0 f32[64, 64]"\n",
-       "    # t37 = prims.matmul(t35, t36)  # t37: "cuda:0 f32[64, 64]"\n",
-       "  t38 = prims.add(t19, t33)  # t38: "cuda:0 f32[64, 64]"\n",
-       "  t39 = prims.add(t4, t37)  # t39: "cuda:0 f32[64, 64]"\n",
-       "  t40 = ltorch.true_divide(t28, 2)  # t40: "cuda:0 f32[64, 64]"\n",
-       "    # _ = prims.convert_element_type(2, float)\n",
-       "    # t40 = prims.div(t28, 2.0)  # t40: "cuda:0 f32[64, 64]"\n",
-       "  p41 = thunder.distributed.prims.reduce_scatter(t40, _DistributedReduceOps_1, _torch_distributed_distributed_c10d_ProcessGroup_0, True)  # p41: "FUTURE cuda:0 f32[32, 64]"\n",
-       "  t42 = thunder.distributed.prims.wait(p41)  # t42: "cuda:0 f32[32, 64]"\n",
-       "  t43 = ltorch.true_divide(t39, 2)  # t43: "cuda:0 f32[64, 64]"\n",
-       "    # _ = prims.convert_element_type(2, float)\n",
-       "    # t43 = prims.div(t39, 2.0)  # t43: "cuda:0 f32[64, 64]"\n",
-       "  p44 = thunder.distributed.prims.reduce_scatter(t43, _DistributedReduceOps_1, _torch_distributed_distributed_c10d_ProcessGroup_0, True)  # p44: "FUTURE cuda:0 f32[32, 64]"\n",
-       "  t45 = thunder.distributed.prims.wait(p44)  # t45: "cuda:0 f32[32, 64]"\n",
-       "  return (({'output': t17, 'flat_args': [t0, t12, t14], 'flat_output': (t17,)}, ((t0, t14, t16), ())), (t38, t45, t42))\n",
-       "
\n" - ], - "text/latex": [ - "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n", - "\\PY{c+c1}{\\PYZsh{} Constructed by Dead Code Elimination (took 1 milliseconds)}\n", - "\\PY{k+kn}{import} \\PY{n+nn}{thunder}\n", - "\\PY{k+kn}{import} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{core}\\PY{n+nn}{.}\\PY{n+nn}{devices} \\PY{k}{as} \\PY{n+nn}{devices}\n", - "\\PY{k+kn}{import} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{core}\\PY{n+nn}{.}\\PY{n+nn}{dtypes} \\PY{k}{as} \\PY{n+nn}{dtypes}\n", - "\\PY{k+kn}{import} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{core}\\PY{n+nn}{.}\\PY{n+nn}{prims} \\PY{k}{as} \\PY{n+nn}{prims}\n", - "\\PY{k+kn}{import} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{distributed}\\PY{n+nn}{.}\\PY{n+nn}{prims}\n", - "\\PY{k+kn}{import} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{torch} \\PY{k}{as} \\PY{n+nn}{ltorch}\n", - "\\PY{k+kn}{import} \\PY{n+nn}{torch}\n", - "\\PY{k+kn}{from} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{executors}\\PY{n+nn}{.}\\PY{n+nn}{torchex} \\PY{k+kn}{import} \\PY{n}{no\\PYZus{}autocast}\n", - "\n", - "\\PY{n+nd}{@torch}\\PY{o}{.}\\PY{n}{no\\PYZus{}grad}\\PY{p}{(}\\PY{p}{)}\n", - "\\PY{n+nd}{@no\\PYZus{}autocast}\\PY{p}{(}\\PY{p}{)}\n", - "\\PY{k}{def} \\PY{n+nf}{\\PYZus{}value\\PYZus{}and\\PYZus{}grad}\\PY{p}{(}\\PY{o}{*}\\PY{n}{args}\\PY{p}{)}\\PY{p}{:}\n", - " \\PY{c+c1}{\\PYZsh{} args: \\PYZdq{}Collection\\PYZdq{} }\n", - " \\PY{n}{t0}\\PY{p}{,} \\PYZbs{}\n", - " \\PY{n}{t1}\\PY{p}{,} \\PYZbs{}\n", - " \\PY{n}{t2}\\PY{p}{,} \\PYZbs{}\n", - " \\PY{o}{=} \\PY{n}{args}\n", - " \\PY{n}{t3} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{full}\\PY{p}{(}\\PY{p}{(}\\PY{l+m+mi}{64}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{device}\\PY{o}{=}\\PY{n}{devices}\\PY{o}{.}\\PY{n}{Device}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{cuda:0}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{dtype}\\PY{o}{=}\\PY{n}{dtypes}\\PY{o}{.}\\PY{n}{float32}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t3: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{n}{t4} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{full}\\PY{p}{(}\\PY{p}{(}\\PY{l+m+mi}{64}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{device}\\PY{o}{=}\\PY{n}{devices}\\PY{o}{.}\\PY{n}{Device}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{cuda:0}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{dtype}\\PY{o}{=}\\PY{n}{dtypes}\\PY{o}{.}\\PY{n}{float32}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t4: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{n}{t5} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{full}\\PY{p}{(}\\PY{p}{(}\\PY{l+m+mi}{64}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{device}\\PY{o}{=}\\PY{n}{devices}\\PY{o}{.}\\PY{n}{Device}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{cuda:0}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{dtype}\\PY{o}{=}\\PY{n}{dtypes}\\PY{o}{.}\\PY{n}{float32}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t5: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{n}{t6} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{full}\\PY{p}{(}\\PY{p}{(}\\PY{l+m+mi}{64}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{device}\\PY{o}{=}\\PY{n}{devices}\\PY{o}{.}\\PY{n}{Device}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{cuda:0}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{dtype}\\PY{o}{=}\\PY{n}{dtypes}\\PY{o}{.}\\PY{n}{float32}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t6: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{n}{t7} \\PY{o}{=} 
\\PY{n}{prims}\\PY{o}{.}\\PY{n}{full}\\PY{p}{(}\\PY{p}{(}\\PY{l+m+mi}{64}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{device}\\PY{o}{=}\\PY{n}{devices}\\PY{o}{.}\\PY{n}{Device}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{cuda:0}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{dtype}\\PY{o}{=}\\PY{n}{dtypes}\\PY{o}{.}\\PY{n}{float32}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t7: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{n}{t8} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{full}\\PY{p}{(}\\PY{p}{(}\\PY{l+m+mi}{64}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{device}\\PY{o}{=}\\PY{n}{devices}\\PY{o}{.}\\PY{n}{Device}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{cuda:0}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{dtype}\\PY{o}{=}\\PY{n}{dtypes}\\PY{o}{.}\\PY{n}{float32}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t8: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{n}{t9} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{full}\\PY{p}{(}\\PY{p}{(}\\PY{l+m+mi}{64}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{device}\\PY{o}{=}\\PY{n}{devices}\\PY{o}{.}\\PY{n}{Device}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{cuda:0}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{dtype}\\PY{o}{=}\\PY{n}{dtypes}\\PY{o}{.}\\PY{n}{float32}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t9: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{n}{t10} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{full}\\PY{p}{(}\\PY{p}{(}\\PY{l+m+mi}{64}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{device}\\PY{o}{=}\\PY{n}{devices}\\PY{o}{.}\\PY{n}{Device}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{cuda:0}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{dtype}\\PY{o}{=}\\PY{n}{dtypes}\\PY{o}{.}\\PY{n}{float32}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t10: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{n}{p11} \\PY{o}{=} \\PY{n}{thunder}\\PY{o}{.}\\PY{n}{distributed}\\PY{o}{.}\\PY{n}{prims}\\PY{o}{.}\\PY{n}{all\\PYZus{}gather}\\PY{p}{(}\\PY{n}{t1}\\PY{p}{,} \\PY{n}{\\PYZus{}torch\\PYZus{}distributed\\PYZus{}distributed\\PYZus{}c10d\\PYZus{}ProcessGroup\\PYZus{}0}\\PY{p}{,} \\PY{k+kc}{True}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} p11: \\PYZdq{}FUTURE cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{n}{t12} \\PY{o}{=} \\PY{n}{thunder}\\PY{o}{.}\\PY{n}{distributed}\\PY{o}{.}\\PY{n}{prims}\\PY{o}{.}\\PY{n}{wait}\\PY{p}{(}\\PY{n}{p11}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t12: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{n}{p13} \\PY{o}{=} \\PY{n}{thunder}\\PY{o}{.}\\PY{n}{distributed}\\PY{o}{.}\\PY{n}{prims}\\PY{o}{.}\\PY{n}{all\\PYZus{}gather}\\PY{p}{(}\\PY{n}{t2}\\PY{p}{,} \\PY{n}{\\PYZus{}torch\\PYZus{}distributed\\PYZus{}distributed\\PYZus{}c10d\\PYZus{}ProcessGroup\\PYZus{}0}\\PY{p}{,} \\PY{k+kc}{True}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} p13: \\PYZdq{}FUTURE cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{n}{t14} \\PY{o}{=} \\PY{n}{thunder}\\PY{o}{.}\\PY{n}{distributed}\\PY{o}{.}\\PY{n}{prims}\\PY{o}{.}\\PY{n}{wait}\\PY{p}{(}\\PY{n}{p13}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t14: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{n}{t15} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{linear}\\PY{p}{(}\\PY{n}{t0}\\PY{p}{,} \\PY{n}{t12}\\PY{p}{,} \\PY{k+kc}{None}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t15: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{n}{t16} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{tanh}\\PY{p}{(}\\PY{n}{t15}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t16: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{n}{t17} \\PY{o}{=} 
\\PY{n}{prims}\\PY{o}{.}\\PY{n}{linear}\\PY{p}{(}\\PY{n}{t16}\\PY{p}{,} \\PY{n}{t14}\\PY{p}{,} \\PY{k+kc}{None}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t17: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{n}{t18} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{add}\\PY{p}{(}\\PY{n}{t6}\\PY{p}{,} \\PY{n}{t7}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t18: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{n}{t19} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{add}\\PY{p}{(}\\PY{n}{t3}\\PY{p}{,} \\PY{n}{t8}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t19: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{n}{t20} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{add}\\PY{p}{(}\\PY{n}{t5}\\PY{p}{,} \\PY{n}{t9}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t20: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{n}{t21} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{t18}\\PY{p}{,} \\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t21: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t21 = prims.reshape(t18, (64, 64)) \\PYZsh{} t21: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{n}{t22} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{matmul}\\PY{p}{(}\\PY{n}{t21}\\PY{p}{,} \\PY{n}{t14}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t22: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t22 = prims.matmul(t21, t14) \\PYZsh{} t22: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{n}{t23} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{t18}\\PY{p}{,} \\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t23: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t23 = prims.reshape(t18, (64, 64)) \\PYZsh{} t23: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{n}{t24} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{transpose}\\PY{p}{(}\\PY{n}{t23}\\PY{p}{,} \\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t24: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{n}{t25} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{t16}\\PY{p}{,} \\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t25: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t25 = prims.reshape(t16, (64, 64)) \\PYZsh{} t25: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{n}{t26} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{matmul}\\PY{p}{(}\\PY{n}{t24}\\PY{p}{,} \\PY{n}{t25}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t26: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t26 = prims.matmul(t24, t25) \\PYZsh{} t26: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{n}{t27} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{add}\\PY{p}{(}\\PY{n}{t10}\\PY{p}{,} \\PY{n}{t22}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t27: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{n}{t28} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{add}\\PY{p}{(}\\PY{n}{t20}\\PY{p}{,} \\PY{n}{t26}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t28: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{n}{t29} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{mul}\\PY{p}{(}\\PY{n}{t16}\\PY{p}{,} \\PY{n}{t16}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t29: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t29 = prims.mul(t16, t16) \\PYZsh{} t29: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{n}{t30} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{sub}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{t29}\\PY{p}{,} \\PY{n}{alpha}\\PY{o}{=}\\PY{k+kc}{None}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t30: 
\\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} \\PYZus{} = prims.convert\\PYZus{}element\\PYZus{}type(1, float)}\n", - " \\PY{c+c1}{\\PYZsh{} t30 = prims.sub(1.0, t29) \\PYZsh{} t30: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{n}{t31} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{mul}\\PY{p}{(}\\PY{n}{t27}\\PY{p}{,} \\PY{n}{t30}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t31: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t31 = prims.mul(t27, t30) \\PYZsh{} t31: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{n}{t32} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{t31}\\PY{p}{,} \\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t32: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t32 = prims.reshape(t31, (64, 64)) \\PYZsh{} t32: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{n}{t33} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{matmul}\\PY{p}{(}\\PY{n}{t32}\\PY{p}{,} \\PY{n}{t12}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t33: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t33 = prims.matmul(t32, t12) \\PYZsh{} t33: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{n}{t34} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{t31}\\PY{p}{,} \\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t34: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t34 = prims.reshape(t31, (64, 64)) \\PYZsh{} t34: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{n}{t35} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{transpose}\\PY{p}{(}\\PY{n}{t34}\\PY{p}{,} \\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t35: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{n}{t36} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{t0}\\PY{p}{,} \\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t36: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t36 = prims.reshape(t0, (64, 64)) \\PYZsh{} t36: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{n}{t37} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{matmul}\\PY{p}{(}\\PY{n}{t35}\\PY{p}{,} \\PY{n}{t36}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t37: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t37 = prims.matmul(t35, t36) \\PYZsh{} t37: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{n}{t38} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{add}\\PY{p}{(}\\PY{n}{t19}\\PY{p}{,} \\PY{n}{t33}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t38: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{n}{t39} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{add}\\PY{p}{(}\\PY{n}{t4}\\PY{p}{,} \\PY{n}{t37}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t39: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{n}{t40} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{true\\PYZus{}divide}\\PY{p}{(}\\PY{n}{t28}\\PY{p}{,} \\PY{l+m+mi}{2}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t40: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} \\PYZus{} = prims.convert\\PYZus{}element\\PYZus{}type(2, float)}\n", - " \\PY{c+c1}{\\PYZsh{} t40 = prims.div(t28, 2.0) \\PYZsh{} t40: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{n}{p41} \\PY{o}{=} \\PY{n}{thunder}\\PY{o}{.}\\PY{n}{distributed}\\PY{o}{.}\\PY{n}{prims}\\PY{o}{.}\\PY{n}{reduce\\PYZus{}scatter}\\PY{p}{(}\\PY{n}{t40}\\PY{p}{,} \\PY{n}{\\PYZus{}DistributedReduceOps\\PYZus{}1}\\PY{p}{,} 
\\PY{n}{\\PYZus{}torch\\PYZus{}distributed\\PYZus{}distributed\\PYZus{}c10d\\PYZus{}ProcessGroup\\PYZus{}0}\\PY{p}{,} \\PY{k+kc}{True}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} p41: \\PYZdq{}FUTURE cuda:0 f32[32, 64]\\PYZdq{}}\n", - " \\PY{n}{t42} \\PY{o}{=} \\PY{n}{thunder}\\PY{o}{.}\\PY{n}{distributed}\\PY{o}{.}\\PY{n}{prims}\\PY{o}{.}\\PY{n}{wait}\\PY{p}{(}\\PY{n}{p41}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t42: \\PYZdq{}cuda:0 f32[32, 64]\\PYZdq{}}\n", - " \\PY{n}{t43} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{true\\PYZus{}divide}\\PY{p}{(}\\PY{n}{t39}\\PY{p}{,} \\PY{l+m+mi}{2}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t43: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} \\PYZus{} = prims.convert\\PYZus{}element\\PYZus{}type(2, float)}\n", - " \\PY{c+c1}{\\PYZsh{} t43 = prims.div(t39, 2.0) \\PYZsh{} t43: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{n}{p44} \\PY{o}{=} \\PY{n}{thunder}\\PY{o}{.}\\PY{n}{distributed}\\PY{o}{.}\\PY{n}{prims}\\PY{o}{.}\\PY{n}{reduce\\PYZus{}scatter}\\PY{p}{(}\\PY{n}{t43}\\PY{p}{,} \\PY{n}{\\PYZus{}DistributedReduceOps\\PYZus{}1}\\PY{p}{,} \\PY{n}{\\PYZus{}torch\\PYZus{}distributed\\PYZus{}distributed\\PYZus{}c10d\\PYZus{}ProcessGroup\\PYZus{}0}\\PY{p}{,} \\PY{k+kc}{True}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} p44: \\PYZdq{}FUTURE cuda:0 f32[32, 64]\\PYZdq{}}\n", - " \\PY{n}{t45} \\PY{o}{=} \\PY{n}{thunder}\\PY{o}{.}\\PY{n}{distributed}\\PY{o}{.}\\PY{n}{prims}\\PY{o}{.}\\PY{n}{wait}\\PY{p}{(}\\PY{n}{p44}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t45: \\PYZdq{}cuda:0 f32[32, 64]\\PYZdq{}}\n", - " \\PY{k}{return} \\PY{p}{(}\\PY{p}{(}\\PY{p}{\\PYZob{}}\\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{output}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{n}{t17}\\PY{p}{,} \\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{flat\\PYZus{}args}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{p}{[}\\PY{n}{t0}\\PY{p}{,} \\PY{n}{t12}\\PY{p}{,} \\PY{n}{t14}\\PY{p}{]}\\PY{p}{,} \\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{flat\\PYZus{}output}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{p}{(}\\PY{n}{t17}\\PY{p}{,}\\PY{p}{)}\\PY{p}{\\PYZcb{}}\\PY{p}{,} \\PY{p}{(}\\PY{p}{(}\\PY{n}{t0}\\PY{p}{,} \\PY{n}{t14}\\PY{p}{,} \\PY{n}{t16}\\PY{p}{)}\\PY{p}{,} \\PY{p}{(}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{,} \\PY{p}{(}\\PY{n}{t38}\\PY{p}{,} \\PY{n}{t45}\\PY{p}{,} \\PY{n}{t42}\\PY{p}{)}\\PY{p}{)}\n", - "\\end{Verbatim}\n" - ], - "text/plain": [ - "# Constructed by Dead Code Elimination (took 1 milliseconds)\n", - "import thunder\n", - "import thunder.core.devices as devices\n", - "import thunder.core.dtypes as dtypes\n", - "import thunder.core.prims as prims\n", - "import thunder.distributed.prims\n", - "import thunder.torch as ltorch\n", - "import torch\n", - "from thunder.executors.torchex import no_autocast\n", - "\n", - "@torch.no_grad()\n", - "@no_autocast()\n", - "def _value_and_grad(*args):\n", - " # args: \"Collection\" \n", - " t0, \\\n", - " t1, \\\n", - " t2, \\\n", - " = args\n", - " t3 = prims.full((64, 64), 1, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t3: \"cuda:0 f32[64, 64]\"\n", - " t4 = prims.full((64, 64), 1, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t4: \"cuda:0 f32[64, 64]\"\n", - " t5 = prims.full((64, 64), 1, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t5: \"cuda:0 f32[64, 64]\"\n", - " t6 = prims.full((64, 64), 1, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t6: \"cuda:0 f32[64, 64]\"\n", - " t7 = prims.full((64, 64), 1, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t7: \"cuda:0 f32[64, 64]\"\n", - " t8 = prims.full((64, 64), 1, device=devices.Device(\"cuda:0\"), 
dtype=dtypes.float32) # t8: \"cuda:0 f32[64, 64]\"\n", - " t9 = prims.full((64, 64), 1, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t9: \"cuda:0 f32[64, 64]\"\n", - " t10 = prims.full((64, 64), 1, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t10: \"cuda:0 f32[64, 64]\"\n", - " p11 = thunder.distributed.prims.all_gather(t1, _torch_distributed_distributed_c10d_ProcessGroup_0, True) # p11: \"FUTURE cuda:0 f32[64, 64]\"\n", - " t12 = thunder.distributed.prims.wait(p11) # t12: \"cuda:0 f32[64, 64]\"\n", - " p13 = thunder.distributed.prims.all_gather(t2, _torch_distributed_distributed_c10d_ProcessGroup_0, True) # p13: \"FUTURE cuda:0 f32[64, 64]\"\n", - " t14 = thunder.distributed.prims.wait(p13) # t14: \"cuda:0 f32[64, 64]\"\n", - " t15 = prims.linear(t0, t12, None) # t15: \"cuda:0 f32[64, 64]\"\n", - " t16 = prims.tanh(t15) # t16: \"cuda:0 f32[64, 64]\"\n", - " t17 = prims.linear(t16, t14, None) # t17: \"cuda:0 f32[64, 64]\"\n", - " t18 = prims.add(t6, t7) # t18: \"cuda:0 f32[64, 64]\"\n", - " t19 = prims.add(t3, t8) # t19: \"cuda:0 f32[64, 64]\"\n", - " t20 = prims.add(t5, t9) # t20: \"cuda:0 f32[64, 64]\"\n", - " t21 = ltorch.reshape(t18, -1, 64) # t21: \"cuda:0 f32[64, 64]\"\n", - " # t21 = prims.reshape(t18, (64, 64)) # t21: \"cuda:0 f32[64, 64]\"\n", - " t22 = ltorch.matmul(t21, t14) # t22: \"cuda:0 f32[64, 64]\"\n", - " # t22 = prims.matmul(t21, t14) # t22: \"cuda:0 f32[64, 64]\"\n", - " t23 = ltorch.reshape(t18, -1, 64) # t23: \"cuda:0 f32[64, 64]\"\n", - " # t23 = prims.reshape(t18, (64, 64)) # t23: \"cuda:0 f32[64, 64]\"\n", - " t24 = prims.transpose(t23, (1, 0)) # t24: \"cuda:0 f32[64, 64]\"\n", - " t25 = ltorch.reshape(t16, -1, 64) # t25: \"cuda:0 f32[64, 64]\"\n", - " # t25 = prims.reshape(t16, (64, 64)) # t25: \"cuda:0 f32[64, 64]\"\n", - " t26 = ltorch.matmul(t24, t25) # t26: \"cuda:0 f32[64, 64]\"\n", - " # t26 = prims.matmul(t24, t25) # t26: \"cuda:0 f32[64, 64]\"\n", - " t27 = prims.add(t10, t22) # t27: \"cuda:0 f32[64, 64]\"\n", - " t28 = prims.add(t20, t26) # t28: \"cuda:0 f32[64, 64]\"\n", - " t29 = ltorch.mul(t16, t16) # t29: \"cuda:0 f32[64, 64]\"\n", - " # t29 = prims.mul(t16, t16) # t29: \"cuda:0 f32[64, 64]\"\n", - " t30 = ltorch.sub(1, t29, alpha=None) # t30: \"cuda:0 f32[64, 64]\"\n", - " # _ = prims.convert_element_type(1, float)\n", - " # t30 = prims.sub(1.0, t29) # t30: \"cuda:0 f32[64, 64]\"\n", - " t31 = ltorch.mul(t27, t30) # t31: \"cuda:0 f32[64, 64]\"\n", - " # t31 = prims.mul(t27, t30) # t31: \"cuda:0 f32[64, 64]\"\n", - " t32 = ltorch.reshape(t31, -1, 64) # t32: \"cuda:0 f32[64, 64]\"\n", - " # t32 = prims.reshape(t31, (64, 64)) # t32: \"cuda:0 f32[64, 64]\"\n", - " t33 = ltorch.matmul(t32, t12) # t33: \"cuda:0 f32[64, 64]\"\n", - " # t33 = prims.matmul(t32, t12) # t33: \"cuda:0 f32[64, 64]\"\n", - " t34 = ltorch.reshape(t31, -1, 64) # t34: \"cuda:0 f32[64, 64]\"\n", - " # t34 = prims.reshape(t31, (64, 64)) # t34: \"cuda:0 f32[64, 64]\"\n", - " t35 = prims.transpose(t34, (1, 0)) # t35: \"cuda:0 f32[64, 64]\"\n", - " t36 = ltorch.reshape(t0, -1, 64) # t36: \"cuda:0 f32[64, 64]\"\n", - " # t36 = prims.reshape(t0, (64, 64)) # t36: \"cuda:0 f32[64, 64]\"\n", - " t37 = ltorch.matmul(t35, t36) # t37: \"cuda:0 f32[64, 64]\"\n", - " # t37 = prims.matmul(t35, t36) # t37: \"cuda:0 f32[64, 64]\"\n", - " t38 = prims.add(t19, t33) # t38: \"cuda:0 f32[64, 64]\"\n", - " t39 = prims.add(t4, t37) # t39: \"cuda:0 f32[64, 64]\"\n", - " t40 = ltorch.true_divide(t28, 2) # t40: \"cuda:0 f32[64, 64]\"\n", - " # _ = prims.convert_element_type(2, 
float)\n", - " # t40 = prims.div(t28, 2.0) # t40: \"cuda:0 f32[64, 64]\"\n", - " p41 = thunder.distributed.prims.reduce_scatter(t40, _DistributedReduceOps_1, _torch_distributed_distributed_c10d_ProcessGroup_0, True) # p41: \"FUTURE cuda:0 f32[32, 64]\"\n", - " t42 = thunder.distributed.prims.wait(p41) # t42: \"cuda:0 f32[32, 64]\"\n", - " t43 = ltorch.true_divide(t39, 2) # t43: \"cuda:0 f32[64, 64]\"\n", - " # _ = prims.convert_element_type(2, float)\n", - " # t43 = prims.div(t39, 2.0) # t43: \"cuda:0 f32[64, 64]\"\n", - " p44 = thunder.distributed.prims.reduce_scatter(t43, _DistributedReduceOps_1, _torch_distributed_distributed_c10d_ProcessGroup_0, True) # p44: \"FUTURE cuda:0 f32[32, 64]\"\n", - " t45 = thunder.distributed.prims.wait(p44) # t45: \"cuda:0 f32[32, 64]\"\n", - " return (({'output': t17, 'flat_args': [t0, t12, t14], 'flat_output': (t17,)}, ((t0, t14, t16), ())), (t38, t45, t42))" - ] - }, - "execution_count": 13, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "from thunder.core.transforms import value_and_grad\n", "\n", @@ -1151,576 +302,9 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
# Constructed by Delete Last Used (took 0 milliseconds)\n",
-       "import torch\n",
-       "import torch.nn.functional\n",
-       "from thunder.executors.torchex import no_autocast\n",
-       "\n",
-       "@torch.no_grad()\n",
-       "@no_autocast()\n",
-       "def _value_and_grad(*args):\n",
-       "  # args: "Collection" \n",
-       "  t0, \\\n",
-       "  t1, \\\n",
-       "  t2, \\\n",
-       "  = args\n",
-       "  del args\n",
-       "  t3 = torch.full((64, 64), 1, device=torch.device("cuda:0"), dtype=torch.float32)  # t3: "cuda:0 f32[64, 64]"\n",
-       "    # t3 = ltorch.full((64, 64), 1, device=torch.device("cuda:0"), dtype=torch.float32)  # t3: "cuda:0 f32[64, 64]"\n",
-       "      # t3 = prims.full((64, 64), 1, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # t3: "cuda:0 f32[64, 64]"\n",
-       "  t4 = torch.full((64, 64), 1, device=torch.device("cuda:0"), dtype=torch.float32)  # t4: "cuda:0 f32[64, 64]"\n",
-       "    # t4 = ltorch.full((64, 64), 1, device=torch.device("cuda:0"), dtype=torch.float32)  # t4: "cuda:0 f32[64, 64]"\n",
-       "      # t4 = prims.full((64, 64), 1, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # t4: "cuda:0 f32[64, 64]"\n",
-       "  t5 = torch.full((64, 64), 1, device=torch.device("cuda:0"), dtype=torch.float32)  # t5: "cuda:0 f32[64, 64]"\n",
-       "    # t5 = ltorch.full((64, 64), 1, device=torch.device("cuda:0"), dtype=torch.float32)  # t5: "cuda:0 f32[64, 64]"\n",
-       "      # t5 = prims.full((64, 64), 1, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # t5: "cuda:0 f32[64, 64]"\n",
-       "  t6 = torch.full((64, 64), 1, device=torch.device("cuda:0"), dtype=torch.float32)  # t6: "cuda:0 f32[64, 64]"\n",
-       "    # t6 = ltorch.full((64, 64), 1, device=torch.device("cuda:0"), dtype=torch.float32)  # t6: "cuda:0 f32[64, 64]"\n",
-       "      # t6 = prims.full((64, 64), 1, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # t6: "cuda:0 f32[64, 64]"\n",
-       "  t7 = torch.full((64, 64), 1, device=torch.device("cuda:0"), dtype=torch.float32)  # t7: "cuda:0 f32[64, 64]"\n",
-       "    # t7 = ltorch.full((64, 64), 1, device=torch.device("cuda:0"), dtype=torch.float32)  # t7: "cuda:0 f32[64, 64]"\n",
-       "      # t7 = prims.full((64, 64), 1, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # t7: "cuda:0 f32[64, 64]"\n",
-       "  t8 = torch.full((64, 64), 1, device=torch.device("cuda:0"), dtype=torch.float32)  # t8: "cuda:0 f32[64, 64]"\n",
-       "    # t8 = ltorch.full((64, 64), 1, device=torch.device("cuda:0"), dtype=torch.float32)  # t8: "cuda:0 f32[64, 64]"\n",
-       "      # t8 = prims.full((64, 64), 1, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # t8: "cuda:0 f32[64, 64]"\n",
-       "  t9 = torch.full((64, 64), 1, device=torch.device("cuda:0"), dtype=torch.float32)  # t9: "cuda:0 f32[64, 64]"\n",
-       "    # t9 = ltorch.full((64, 64), 1, device=torch.device("cuda:0"), dtype=torch.float32)  # t9: "cuda:0 f32[64, 64]"\n",
-       "      # t9 = prims.full((64, 64), 1, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # t9: "cuda:0 f32[64, 64]"\n",
-       "  t10 = torch.full((64, 64), 1, device=torch.device("cuda:0"), dtype=torch.float32)  # t10: "cuda:0 f32[64, 64]"\n",
-       "    # t10 = ltorch.full((64, 64), 1, device=torch.device("cuda:0"), dtype=torch.float32)  # t10: "cuda:0 f32[64, 64]"\n",
-       "      # t10 = prims.full((64, 64), 1, device=devices.Device("cuda:0"), dtype=dtypes.float32)  # t10: "cuda:0 f32[64, 64]"\n",
-       "  p11 = torch_all_gather_prim_impl(t1, _torch_distributed_distributed_c10d_ProcessGroup_2, True)  # p11: "FUTURE cuda:0 f32[64, 64]"\n",
-       "  del t1\n",
-       "  t12 = torch_wait_prim_impl(p11)  # t12: "cuda:0 f32[64, 64]"\n",
-       "  del p11\n",
-       "  p13 = torch_all_gather_prim_impl(t2, _torch_distributed_distributed_c10d_ProcessGroup_2, True)  # p13: "FUTURE cuda:0 f32[64, 64]"\n",
-       "  del t2\n",
-       "  t14 = torch_wait_prim_impl(p13)  # t14: "cuda:0 f32[64, 64]"\n",
-       "  del p13\n",
-       "  t15 = torch.nn.functional.linear(t0, t12, None)  # t15: "cuda:0 f32[64, 64]"\n",
-       "    # t15 = ltorch.linear(t0, t12, None)  # t15: "cuda:0 f32[64, 64]"\n",
-       "      # t15 = prims.linear(t0, t12, None)  # t15: "cuda:0 f32[64, 64]"\n",
-       "  t16 = torch.tanh(t15)  # t16: "cuda:0 f32[64, 64]"\n",
-       "    # t16 = ltorch.tanh(t15)  # t16: "cuda:0 f32[64, 64]"\n",
-       "      # t16 = prims.tanh(t15)  # t16: "cuda:0 f32[64, 64]"\n",
-       "  del t15\n",
-       "  t17 = torch.nn.functional.linear(t16, t14, None)  # t17: "cuda:0 f32[64, 64]"\n",
-       "    # t17 = ltorch.linear(t16, t14, None)  # t17: "cuda:0 f32[64, 64]"\n",
-       "      # t17 = prims.linear(t16, t14, None)  # t17: "cuda:0 f32[64, 64]"\n",
-       "  t18 = torch.add(t6, t7)  # t18: "cuda:0 f32[64, 64]"\n",
-       "    # t18 = ltorch.add(t6, t7, alpha=None)  # t18: "cuda:0 f32[64, 64]"\n",
-       "      # t18 = prims.add(t6, t7)  # t18: "cuda:0 f32[64, 64]"\n",
-       "  del t6, t7\n",
-       "  t19 = torch.add(t3, t8)  # t19: "cuda:0 f32[64, 64]"\n",
-       "    # t19 = ltorch.add(t3, t8, alpha=None)  # t19: "cuda:0 f32[64, 64]"\n",
-       "      # t19 = prims.add(t3, t8)  # t19: "cuda:0 f32[64, 64]"\n",
-       "  del t3, t8\n",
-       "  t20 = torch.add(t5, t9)  # t20: "cuda:0 f32[64, 64]"\n",
-       "    # t20 = ltorch.add(t5, t9, alpha=None)  # t20: "cuda:0 f32[64, 64]"\n",
-       "      # t20 = prims.add(t5, t9)  # t20: "cuda:0 f32[64, 64]"\n",
-       "  del t5, t9\n",
-       "  t21 = torch.reshape(t18, (-1, 64))  # t21: "cuda:0 f32[64, 64]"\n",
-       "    # t21 = ltorch.reshape(t18, (-1, 64))  # t21: "cuda:0 f32[64, 64]"\n",
-       "      # t21 = prims.reshape(t18, (64, 64))  # t21: "cuda:0 f32[64, 64]"\n",
-       "  t22 = torch.matmul(t21, t14)  # t22: "cuda:0 f32[64, 64]"\n",
-       "    # t22 = ltorch.matmul(t21, t14)  # t22: "cuda:0 f32[64, 64]"\n",
-       "      # t22 = prims.matmul(t21, t14)  # t22: "cuda:0 f32[64, 64]"\n",
-       "  del t21\n",
-       "  t23 = torch.reshape(t18, (-1, 64))  # t23: "cuda:0 f32[64, 64]"\n",
-       "    # t23 = ltorch.reshape(t18, (-1, 64))  # t23: "cuda:0 f32[64, 64]"\n",
-       "      # t23 = prims.reshape(t18, (64, 64))  # t23: "cuda:0 f32[64, 64]"\n",
-       "  del t18\n",
-       "  t24 = torch.permute(t23, (1, 0))  # t24: "cuda:0 f32[64, 64]"\n",
-       "    # t24 = ltorch.permute(t23, (1, 0))  # t24: "cuda:0 f32[64, 64]"\n",
-       "      # t24 = prims.transpose(t23, (1, 0))  # t24: "cuda:0 f32[64, 64]"\n",
-       "  del t23\n",
-       "  t25 = torch.reshape(t16, (-1, 64))  # t25: "cuda:0 f32[64, 64]"\n",
-       "    # t25 = ltorch.reshape(t16, (-1, 64))  # t25: "cuda:0 f32[64, 64]"\n",
-       "      # t25 = prims.reshape(t16, (64, 64))  # t25: "cuda:0 f32[64, 64]"\n",
-       "  t26 = torch.matmul(t24, t25)  # t26: "cuda:0 f32[64, 64]"\n",
-       "    # t26 = ltorch.matmul(t24, t25)  # t26: "cuda:0 f32[64, 64]"\n",
-       "      # t26 = prims.matmul(t24, t25)  # t26: "cuda:0 f32[64, 64]"\n",
-       "  del t24, t25\n",
-       "  t27 = torch.add(t10, t22)  # t27: "cuda:0 f32[64, 64]"\n",
-       "    # t27 = ltorch.add(t10, t22, alpha=None)  # t27: "cuda:0 f32[64, 64]"\n",
-       "      # t27 = prims.add(t10, t22)  # t27: "cuda:0 f32[64, 64]"\n",
-       "  del t10, t22\n",
-       "  t28 = torch.add(t20, t26)  # t28: "cuda:0 f32[64, 64]"\n",
-       "    # t28 = ltorch.add(t20, t26, alpha=None)  # t28: "cuda:0 f32[64, 64]"\n",
-       "      # t28 = prims.add(t20, t26)  # t28: "cuda:0 f32[64, 64]"\n",
-       "  del t20, t26\n",
-       "  t29 = torch.mul(t16, t16)  # t29: "cuda:0 f32[64, 64]"\n",
-       "    # t29 = ltorch.mul(t16, t16)  # t29: "cuda:0 f32[64, 64]"\n",
-       "      # t29 = prims.mul(t16, t16)  # t29: "cuda:0 f32[64, 64]"\n",
-       "  t30 = torch.sub(1, t29)  # t30: "cuda:0 f32[64, 64]"\n",
-       "    # t30 = ltorch.sub(1, t29, alpha=None)  # t30: "cuda:0 f32[64, 64]"\n",
-       "      # _ = prims.convert_element_type(1, float)\n",
-       "      # t30 = prims.sub(1.0, t29)  # t30: "cuda:0 f32[64, 64]"\n",
-       "  del t29\n",
-       "  t31 = torch.mul(t27, t30)  # t31: "cuda:0 f32[64, 64]"\n",
-       "    # t31 = ltorch.mul(t27, t30)  # t31: "cuda:0 f32[64, 64]"\n",
-       "      # t31 = prims.mul(t27, t30)  # t31: "cuda:0 f32[64, 64]"\n",
-       "  del t27, t30\n",
-       "  t32 = torch.reshape(t31, (-1, 64))  # t32: "cuda:0 f32[64, 64]"\n",
-       "    # t32 = ltorch.reshape(t31, (-1, 64))  # t32: "cuda:0 f32[64, 64]"\n",
-       "      # t32 = prims.reshape(t31, (64, 64))  # t32: "cuda:0 f32[64, 64]"\n",
-       "  t33 = torch.matmul(t32, t12)  # t33: "cuda:0 f32[64, 64]"\n",
-       "    # t33 = ltorch.matmul(t32, t12)  # t33: "cuda:0 f32[64, 64]"\n",
-       "      # t33 = prims.matmul(t32, t12)  # t33: "cuda:0 f32[64, 64]"\n",
-       "  del t32\n",
-       "  t34 = torch.reshape(t31, (-1, 64))  # t34: "cuda:0 f32[64, 64]"\n",
-       "    # t34 = ltorch.reshape(t31, (-1, 64))  # t34: "cuda:0 f32[64, 64]"\n",
-       "      # t34 = prims.reshape(t31, (64, 64))  # t34: "cuda:0 f32[64, 64]"\n",
-       "  del t31\n",
-       "  t35 = torch.permute(t34, (1, 0))  # t35: "cuda:0 f32[64, 64]"\n",
-       "    # t35 = ltorch.permute(t34, (1, 0))  # t35: "cuda:0 f32[64, 64]"\n",
-       "      # t35 = prims.transpose(t34, (1, 0))  # t35: "cuda:0 f32[64, 64]"\n",
-       "  del t34\n",
-       "  t36 = torch.reshape(t0, (-1, 64))  # t36: "cuda:0 f32[64, 64]"\n",
-       "    # t36 = ltorch.reshape(t0, (-1, 64))  # t36: "cuda:0 f32[64, 64]"\n",
-       "      # t36 = prims.reshape(t0, (64, 64))  # t36: "cuda:0 f32[64, 64]"\n",
-       "  t37 = torch.matmul(t35, t36)  # t37: "cuda:0 f32[64, 64]"\n",
-       "    # t37 = ltorch.matmul(t35, t36)  # t37: "cuda:0 f32[64, 64]"\n",
-       "      # t37 = prims.matmul(t35, t36)  # t37: "cuda:0 f32[64, 64]"\n",
-       "  del t35, t36\n",
-       "  t38 = torch.add(t19, t33)  # t38: "cuda:0 f32[64, 64]"\n",
-       "    # t38 = ltorch.add(t19, t33, alpha=None)  # t38: "cuda:0 f32[64, 64]"\n",
-       "      # t38 = prims.add(t19, t33)  # t38: "cuda:0 f32[64, 64]"\n",
-       "  del t19, t33\n",
-       "  t39 = torch.add(t4, t37)  # t39: "cuda:0 f32[64, 64]"\n",
-       "    # t39 = ltorch.add(t4, t37, alpha=None)  # t39: "cuda:0 f32[64, 64]"\n",
-       "      # t39 = prims.add(t4, t37)  # t39: "cuda:0 f32[64, 64]"\n",
-       "  del t4, t37\n",
-       "  t40 = torch.true_divide(t28, 2)  # t40: "cuda:0 f32[64, 64]"\n",
-       "    # t40 = ltorch.true_divide(t28, 2)  # t40: "cuda:0 f32[64, 64]"\n",
-       "      # _ = prims.convert_element_type(2, float)\n",
-       "      # t40 = prims.div(t28, 2.0)  # t40: "cuda:0 f32[64, 64]"\n",
-       "  del t28\n",
-       "  p41 = torch_reduce_scatter_prim_impl(t40, _DistributedReduceOps_3, _torch_distributed_distributed_c10d_ProcessGroup_2, True)  # p41: "FUTURE cuda:0 f32[32, 64]"\n",
-       "  del t40\n",
-       "  t42 = torch_wait_prim_impl(p41)  # t42: "cuda:0 f32[32, 64]"\n",
-       "  del p41\n",
-       "  t43 = torch.true_divide(t39, 2)  # t43: "cuda:0 f32[64, 64]"\n",
-       "    # t43 = ltorch.true_divide(t39, 2)  # t43: "cuda:0 f32[64, 64]"\n",
-       "      # _ = prims.convert_element_type(2, float)\n",
-       "      # t43 = prims.div(t39, 2.0)  # t43: "cuda:0 f32[64, 64]"\n",
-       "  del t39\n",
-       "  p44 = torch_reduce_scatter_prim_impl(t43, _DistributedReduceOps_3, _torch_distributed_distributed_c10d_ProcessGroup_2, True)  # p44: "FUTURE cuda:0 f32[32, 64]"\n",
-       "  del t43\n",
-       "  t45 = torch_wait_prim_impl(p44)  # t45: "cuda:0 f32[32, 64]"\n",
-       "  del p44\n",
-       "  return (({'output': t17, 'flat_args': [t0, t12, t14], 'flat_output': (t17,)}, ((t0, t14, t16), ())), (t38, t45, t42))\n",
-       "
\n" - ], - "text/latex": [ - "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n", - "\\PY{c+c1}{\\PYZsh{} Constructed by Delete Last Used (took 0 milliseconds)}\n", - "\\PY{k+kn}{import} \\PY{n+nn}{torch}\n", - "\\PY{k+kn}{import} \\PY{n+nn}{torch}\\PY{n+nn}{.}\\PY{n+nn}{nn}\\PY{n+nn}{.}\\PY{n+nn}{functional}\n", - "\\PY{k+kn}{from} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{executors}\\PY{n+nn}{.}\\PY{n+nn}{torchex} \\PY{k+kn}{import} \\PY{n}{no\\PYZus{}autocast}\n", - "\n", - "\\PY{n+nd}{@torch}\\PY{o}{.}\\PY{n}{no\\PYZus{}grad}\\PY{p}{(}\\PY{p}{)}\n", - "\\PY{n+nd}{@no\\PYZus{}autocast}\\PY{p}{(}\\PY{p}{)}\n", - "\\PY{k}{def} \\PY{n+nf}{\\PYZus{}value\\PYZus{}and\\PYZus{}grad}\\PY{p}{(}\\PY{o}{*}\\PY{n}{args}\\PY{p}{)}\\PY{p}{:}\n", - " \\PY{c+c1}{\\PYZsh{} args: \\PYZdq{}Collection\\PYZdq{} }\n", - " \\PY{n}{t0}\\PY{p}{,} \\PYZbs{}\n", - " \\PY{n}{t1}\\PY{p}{,} \\PYZbs{}\n", - " \\PY{n}{t2}\\PY{p}{,} \\PYZbs{}\n", - " \\PY{o}{=} \\PY{n}{args}\n", - " \\PY{k}{del} \\PY{n}{args}\n", - " \\PY{n}{t3} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{full}\\PY{p}{(}\\PY{p}{(}\\PY{l+m+mi}{64}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{device}\\PY{o}{=}\\PY{n}{torch}\\PY{o}{.}\\PY{n}{device}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{cuda:0}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{dtype}\\PY{o}{=}\\PY{n}{torch}\\PY{o}{.}\\PY{n}{float32}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t3: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t3 = ltorch.full((64, 64), 1, device=torch.device(\\PYZdq{}cuda:0\\PYZdq{}), dtype=torch.float32) \\PYZsh{} t3: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t3 = prims.full((64, 64), 1, device=devices.Device(\\PYZdq{}cuda:0\\PYZdq{}), dtype=dtypes.float32) \\PYZsh{} t3: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{n}{t4} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{full}\\PY{p}{(}\\PY{p}{(}\\PY{l+m+mi}{64}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{device}\\PY{o}{=}\\PY{n}{torch}\\PY{o}{.}\\PY{n}{device}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{cuda:0}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{dtype}\\PY{o}{=}\\PY{n}{torch}\\PY{o}{.}\\PY{n}{float32}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t4: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t4 = ltorch.full((64, 64), 1, device=torch.device(\\PYZdq{}cuda:0\\PYZdq{}), dtype=torch.float32) \\PYZsh{} t4: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t4 = prims.full((64, 64), 1, device=devices.Device(\\PYZdq{}cuda:0\\PYZdq{}), dtype=dtypes.float32) \\PYZsh{} t4: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{n}{t5} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{full}\\PY{p}{(}\\PY{p}{(}\\PY{l+m+mi}{64}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{device}\\PY{o}{=}\\PY{n}{torch}\\PY{o}{.}\\PY{n}{device}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{cuda:0}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{dtype}\\PY{o}{=}\\PY{n}{torch}\\PY{o}{.}\\PY{n}{float32}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t5: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t5 = ltorch.full((64, 64), 1, device=torch.device(\\PYZdq{}cuda:0\\PYZdq{}), dtype=torch.float32) \\PYZsh{} t5: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t5 = prims.full((64, 64), 1, device=devices.Device(\\PYZdq{}cuda:0\\PYZdq{}), dtype=dtypes.float32) \\PYZsh{} t5: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{n}{t6} \\PY{o}{=} 
\\PY{n}{torch}\\PY{o}{.}\\PY{n}{full}\\PY{p}{(}\\PY{p}{(}\\PY{l+m+mi}{64}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{device}\\PY{o}{=}\\PY{n}{torch}\\PY{o}{.}\\PY{n}{device}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{cuda:0}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{dtype}\\PY{o}{=}\\PY{n}{torch}\\PY{o}{.}\\PY{n}{float32}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t6: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t6 = ltorch.full((64, 64), 1, device=torch.device(\\PYZdq{}cuda:0\\PYZdq{}), dtype=torch.float32) \\PYZsh{} t6: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t6 = prims.full((64, 64), 1, device=devices.Device(\\PYZdq{}cuda:0\\PYZdq{}), dtype=dtypes.float32) \\PYZsh{} t6: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{n}{t7} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{full}\\PY{p}{(}\\PY{p}{(}\\PY{l+m+mi}{64}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{device}\\PY{o}{=}\\PY{n}{torch}\\PY{o}{.}\\PY{n}{device}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{cuda:0}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{dtype}\\PY{o}{=}\\PY{n}{torch}\\PY{o}{.}\\PY{n}{float32}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t7: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t7 = ltorch.full((64, 64), 1, device=torch.device(\\PYZdq{}cuda:0\\PYZdq{}), dtype=torch.float32) \\PYZsh{} t7: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t7 = prims.full((64, 64), 1, device=devices.Device(\\PYZdq{}cuda:0\\PYZdq{}), dtype=dtypes.float32) \\PYZsh{} t7: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{n}{t8} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{full}\\PY{p}{(}\\PY{p}{(}\\PY{l+m+mi}{64}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{device}\\PY{o}{=}\\PY{n}{torch}\\PY{o}{.}\\PY{n}{device}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{cuda:0}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{dtype}\\PY{o}{=}\\PY{n}{torch}\\PY{o}{.}\\PY{n}{float32}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t8: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t8 = ltorch.full((64, 64), 1, device=torch.device(\\PYZdq{}cuda:0\\PYZdq{}), dtype=torch.float32) \\PYZsh{} t8: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t8 = prims.full((64, 64), 1, device=devices.Device(\\PYZdq{}cuda:0\\PYZdq{}), dtype=dtypes.float32) \\PYZsh{} t8: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{n}{t9} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{full}\\PY{p}{(}\\PY{p}{(}\\PY{l+m+mi}{64}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{device}\\PY{o}{=}\\PY{n}{torch}\\PY{o}{.}\\PY{n}{device}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{cuda:0}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{dtype}\\PY{o}{=}\\PY{n}{torch}\\PY{o}{.}\\PY{n}{float32}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t9: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t9 = ltorch.full((64, 64), 1, device=torch.device(\\PYZdq{}cuda:0\\PYZdq{}), dtype=torch.float32) \\PYZsh{} t9: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t9 = prims.full((64, 64), 1, device=devices.Device(\\PYZdq{}cuda:0\\PYZdq{}), dtype=dtypes.float32) \\PYZsh{} t9: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{n}{t10} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{full}\\PY{p}{(}\\PY{p}{(}\\PY{l+m+mi}{64}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{,} 
\\PY{n}{device}\\PY{o}{=}\\PY{n}{torch}\\PY{o}{.}\\PY{n}{device}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{cuda:0}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{dtype}\\PY{o}{=}\\PY{n}{torch}\\PY{o}{.}\\PY{n}{float32}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t10: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t10 = ltorch.full((64, 64), 1, device=torch.device(\\PYZdq{}cuda:0\\PYZdq{}), dtype=torch.float32) \\PYZsh{} t10: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t10 = prims.full((64, 64), 1, device=devices.Device(\\PYZdq{}cuda:0\\PYZdq{}), dtype=dtypes.float32) \\PYZsh{} t10: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{n}{p11} \\PY{o}{=} \\PY{n}{torch\\PYZus{}all\\PYZus{}gather\\PYZus{}prim\\PYZus{}impl}\\PY{p}{(}\\PY{n}{t1}\\PY{p}{,} \\PY{n}{\\PYZus{}torch\\PYZus{}distributed\\PYZus{}distributed\\PYZus{}c10d\\PYZus{}ProcessGroup\\PYZus{}2}\\PY{p}{,} \\PY{k+kc}{True}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} p11: \\PYZdq{}FUTURE cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{k}{del} \\PY{n}{t1}\n", - " \\PY{n}{t12} \\PY{o}{=} \\PY{n}{torch\\PYZus{}wait\\PYZus{}prim\\PYZus{}impl}\\PY{p}{(}\\PY{n}{p11}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t12: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{k}{del} \\PY{n}{p11}\n", - " \\PY{n}{p13} \\PY{o}{=} \\PY{n}{torch\\PYZus{}all\\PYZus{}gather\\PYZus{}prim\\PYZus{}impl}\\PY{p}{(}\\PY{n}{t2}\\PY{p}{,} \\PY{n}{\\PYZus{}torch\\PYZus{}distributed\\PYZus{}distributed\\PYZus{}c10d\\PYZus{}ProcessGroup\\PYZus{}2}\\PY{p}{,} \\PY{k+kc}{True}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} p13: \\PYZdq{}FUTURE cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{k}{del} \\PY{n}{t2}\n", - " \\PY{n}{t14} \\PY{o}{=} \\PY{n}{torch\\PYZus{}wait\\PYZus{}prim\\PYZus{}impl}\\PY{p}{(}\\PY{n}{p13}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t14: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{k}{del} \\PY{n}{p13}\n", - " \\PY{n}{t15} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{nn}\\PY{o}{.}\\PY{n}{functional}\\PY{o}{.}\\PY{n}{linear}\\PY{p}{(}\\PY{n}{t0}\\PY{p}{,} \\PY{n}{t12}\\PY{p}{,} \\PY{k+kc}{None}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t15: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t15 = ltorch.linear(t0, t12, None) \\PYZsh{} t15: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t15 = prims.linear(t0, t12, None) \\PYZsh{} t15: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{n}{t16} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{tanh}\\PY{p}{(}\\PY{n}{t15}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t16: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t16 = ltorch.tanh(t15) \\PYZsh{} t16: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t16 = prims.tanh(t15) \\PYZsh{} t16: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{k}{del} \\PY{n}{t15}\n", - " \\PY{n}{t17} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{nn}\\PY{o}{.}\\PY{n}{functional}\\PY{o}{.}\\PY{n}{linear}\\PY{p}{(}\\PY{n}{t16}\\PY{p}{,} \\PY{n}{t14}\\PY{p}{,} \\PY{k+kc}{None}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t17: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t17 = ltorch.linear(t16, t14, None) \\PYZsh{} t17: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t17 = prims.linear(t16, t14, None) \\PYZsh{} t17: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{n}{t18} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{add}\\PY{p}{(}\\PY{n}{t6}\\PY{p}{,} \\PY{n}{t7}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t18: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t18 = ltorch.add(t6, t7, alpha=None) \\PYZsh{} t18: \\PYZdq{}cuda:0 
f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t18 = prims.add(t6, t7) \\PYZsh{} t18: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{k}{del} \\PY{n}{t6}\\PY{p}{,} \\PY{n}{t7}\n", - " \\PY{n}{t19} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{add}\\PY{p}{(}\\PY{n}{t3}\\PY{p}{,} \\PY{n}{t8}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t19: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t19 = ltorch.add(t3, t8, alpha=None) \\PYZsh{} t19: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t19 = prims.add(t3, t8) \\PYZsh{} t19: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{k}{del} \\PY{n}{t3}\\PY{p}{,} \\PY{n}{t8}\n", - " \\PY{n}{t20} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{add}\\PY{p}{(}\\PY{n}{t5}\\PY{p}{,} \\PY{n}{t9}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t20: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t20 = ltorch.add(t5, t9, alpha=None) \\PYZsh{} t20: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t20 = prims.add(t5, t9) \\PYZsh{} t20: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{k}{del} \\PY{n}{t5}\\PY{p}{,} \\PY{n}{t9}\n", - " \\PY{n}{t21} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{t18}\\PY{p}{,} \\PY{p}{(}\\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t21: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t21 = ltorch.reshape(t18, (\\PYZhy{}1, 64)) \\PYZsh{} t21: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t21 = prims.reshape(t18, (64, 64)) \\PYZsh{} t21: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{n}{t22} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{matmul}\\PY{p}{(}\\PY{n}{t21}\\PY{p}{,} \\PY{n}{t14}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t22: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t22 = ltorch.matmul(t21, t14) \\PYZsh{} t22: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t22 = prims.matmul(t21, t14) \\PYZsh{} t22: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{k}{del} \\PY{n}{t21}\n", - " \\PY{n}{t23} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{t18}\\PY{p}{,} \\PY{p}{(}\\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t23: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t23 = ltorch.reshape(t18, (\\PYZhy{}1, 64)) \\PYZsh{} t23: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t23 = prims.reshape(t18, (64, 64)) \\PYZsh{} t23: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{k}{del} \\PY{n}{t18}\n", - " \\PY{n}{t24} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{permute}\\PY{p}{(}\\PY{n}{t23}\\PY{p}{,} \\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t24: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t24 = ltorch.permute(t23, (1, 0)) \\PYZsh{} t24: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t24 = prims.transpose(t23, (1, 0)) \\PYZsh{} t24: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{k}{del} \\PY{n}{t23}\n", - " \\PY{n}{t25} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{t16}\\PY{p}{,} \\PY{p}{(}\\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t25: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t25 = ltorch.reshape(t16, (\\PYZhy{}1, 64)) \\PYZsh{} t25: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t25 = prims.reshape(t16, 
(64, 64)) \\PYZsh{} t25: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{n}{t26} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{matmul}\\PY{p}{(}\\PY{n}{t24}\\PY{p}{,} \\PY{n}{t25}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t26: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t26 = ltorch.matmul(t24, t25) \\PYZsh{} t26: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t26 = prims.matmul(t24, t25) \\PYZsh{} t26: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{k}{del} \\PY{n}{t24}\\PY{p}{,} \\PY{n}{t25}\n", - " \\PY{n}{t27} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{add}\\PY{p}{(}\\PY{n}{t10}\\PY{p}{,} \\PY{n}{t22}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t27: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t27 = ltorch.add(t10, t22, alpha=None) \\PYZsh{} t27: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t27 = prims.add(t10, t22) \\PYZsh{} t27: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{k}{del} \\PY{n}{t10}\\PY{p}{,} \\PY{n}{t22}\n", - " \\PY{n}{t28} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{add}\\PY{p}{(}\\PY{n}{t20}\\PY{p}{,} \\PY{n}{t26}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t28: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t28 = ltorch.add(t20, t26, alpha=None) \\PYZsh{} t28: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t28 = prims.add(t20, t26) \\PYZsh{} t28: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{k}{del} \\PY{n}{t20}\\PY{p}{,} \\PY{n}{t26}\n", - " \\PY{n}{t29} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{mul}\\PY{p}{(}\\PY{n}{t16}\\PY{p}{,} \\PY{n}{t16}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t29: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t29 = ltorch.mul(t16, t16) \\PYZsh{} t29: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t29 = prims.mul(t16, t16) \\PYZsh{} t29: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{n}{t30} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{sub}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{t29}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t30: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t30 = ltorch.sub(1, t29, alpha=None) \\PYZsh{} t30: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} \\PYZus{} = prims.convert\\PYZus{}element\\PYZus{}type(1, float)}\n", - " \\PY{c+c1}{\\PYZsh{} t30 = prims.sub(1.0, t29) \\PYZsh{} t30: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{k}{del} \\PY{n}{t29}\n", - " \\PY{n}{t31} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{mul}\\PY{p}{(}\\PY{n}{t27}\\PY{p}{,} \\PY{n}{t30}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t31: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t31 = ltorch.mul(t27, t30) \\PYZsh{} t31: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t31 = prims.mul(t27, t30) \\PYZsh{} t31: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{k}{del} \\PY{n}{t27}\\PY{p}{,} \\PY{n}{t30}\n", - " \\PY{n}{t32} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{t31}\\PY{p}{,} \\PY{p}{(}\\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t32: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t32 = ltorch.reshape(t31, (\\PYZhy{}1, 64)) \\PYZsh{} t32: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t32 = prims.reshape(t31, (64, 64)) \\PYZsh{} t32: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{n}{t33} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{matmul}\\PY{p}{(}\\PY{n}{t32}\\PY{p}{,} \\PY{n}{t12}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t33: 
\\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t33 = ltorch.matmul(t32, t12) \\PYZsh{} t33: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t33 = prims.matmul(t32, t12) \\PYZsh{} t33: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{k}{del} \\PY{n}{t32}\n", - " \\PY{n}{t34} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{t31}\\PY{p}{,} \\PY{p}{(}\\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t34: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t34 = ltorch.reshape(t31, (\\PYZhy{}1, 64)) \\PYZsh{} t34: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t34 = prims.reshape(t31, (64, 64)) \\PYZsh{} t34: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{k}{del} \\PY{n}{t31}\n", - " \\PY{n}{t35} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{permute}\\PY{p}{(}\\PY{n}{t34}\\PY{p}{,} \\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t35: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t35 = ltorch.permute(t34, (1, 0)) \\PYZsh{} t35: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t35 = prims.transpose(t34, (1, 0)) \\PYZsh{} t35: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{k}{del} \\PY{n}{t34}\n", - " \\PY{n}{t36} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{t0}\\PY{p}{,} \\PY{p}{(}\\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t36: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t36 = ltorch.reshape(t0, (\\PYZhy{}1, 64)) \\PYZsh{} t36: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t36 = prims.reshape(t0, (64, 64)) \\PYZsh{} t36: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{n}{t37} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{matmul}\\PY{p}{(}\\PY{n}{t35}\\PY{p}{,} \\PY{n}{t36}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t37: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t37 = ltorch.matmul(t35, t36) \\PYZsh{} t37: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t37 = prims.matmul(t35, t36) \\PYZsh{} t37: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{k}{del} \\PY{n}{t35}\\PY{p}{,} \\PY{n}{t36}\n", - " \\PY{n}{t38} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{add}\\PY{p}{(}\\PY{n}{t19}\\PY{p}{,} \\PY{n}{t33}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t38: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t38 = ltorch.add(t19, t33, alpha=None) \\PYZsh{} t38: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t38 = prims.add(t19, t33) \\PYZsh{} t38: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{k}{del} \\PY{n}{t19}\\PY{p}{,} \\PY{n}{t33}\n", - " \\PY{n}{t39} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{add}\\PY{p}{(}\\PY{n}{t4}\\PY{p}{,} \\PY{n}{t37}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t39: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t39 = ltorch.add(t4, t37, alpha=None) \\PYZsh{} t39: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t39 = prims.add(t4, t37) \\PYZsh{} t39: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{k}{del} \\PY{n}{t4}\\PY{p}{,} \\PY{n}{t37}\n", - " \\PY{n}{t40} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{true\\PYZus{}divide}\\PY{p}{(}\\PY{n}{t28}\\PY{p}{,} \\PY{l+m+mi}{2}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t40: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t40 = ltorch.true\\PYZus{}divide(t28, 2) \\PYZsh{} t40: 
\\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} \\PYZus{} = prims.convert\\PYZus{}element\\PYZus{}type(2, float)}\n", - " \\PY{c+c1}{\\PYZsh{} t40 = prims.div(t28, 2.0) \\PYZsh{} t40: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{k}{del} \\PY{n}{t28}\n", - " \\PY{n}{p41} \\PY{o}{=} \\PY{n}{torch\\PYZus{}reduce\\PYZus{}scatter\\PYZus{}prim\\PYZus{}impl}\\PY{p}{(}\\PY{n}{t40}\\PY{p}{,} \\PY{n}{\\PYZus{}DistributedReduceOps\\PYZus{}3}\\PY{p}{,} \\PY{n}{\\PYZus{}torch\\PYZus{}distributed\\PYZus{}distributed\\PYZus{}c10d\\PYZus{}ProcessGroup\\PYZus{}2}\\PY{p}{,} \\PY{k+kc}{True}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} p41: \\PYZdq{}FUTURE cuda:0 f32[32, 64]\\PYZdq{}}\n", - " \\PY{k}{del} \\PY{n}{t40}\n", - " \\PY{n}{t42} \\PY{o}{=} \\PY{n}{torch\\PYZus{}wait\\PYZus{}prim\\PYZus{}impl}\\PY{p}{(}\\PY{n}{p41}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t42: \\PYZdq{}cuda:0 f32[32, 64]\\PYZdq{}}\n", - " \\PY{k}{del} \\PY{n}{p41}\n", - " \\PY{n}{t43} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{true\\PYZus{}divide}\\PY{p}{(}\\PY{n}{t39}\\PY{p}{,} \\PY{l+m+mi}{2}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t43: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} t43 = ltorch.true\\PYZus{}divide(t39, 2) \\PYZsh{} t43: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{c+c1}{\\PYZsh{} \\PYZus{} = prims.convert\\PYZus{}element\\PYZus{}type(2, float)}\n", - " \\PY{c+c1}{\\PYZsh{} t43 = prims.div(t39, 2.0) \\PYZsh{} t43: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n", - " \\PY{k}{del} \\PY{n}{t39}\n", - " \\PY{n}{p44} \\PY{o}{=} \\PY{n}{torch\\PYZus{}reduce\\PYZus{}scatter\\PYZus{}prim\\PYZus{}impl}\\PY{p}{(}\\PY{n}{t43}\\PY{p}{,} \\PY{n}{\\PYZus{}DistributedReduceOps\\PYZus{}3}\\PY{p}{,} \\PY{n}{\\PYZus{}torch\\PYZus{}distributed\\PYZus{}distributed\\PYZus{}c10d\\PYZus{}ProcessGroup\\PYZus{}2}\\PY{p}{,} \\PY{k+kc}{True}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} p44: \\PYZdq{}FUTURE cuda:0 f32[32, 64]\\PYZdq{}}\n", - " \\PY{k}{del} \\PY{n}{t43}\n", - " \\PY{n}{t45} \\PY{o}{=} \\PY{n}{torch\\PYZus{}wait\\PYZus{}prim\\PYZus{}impl}\\PY{p}{(}\\PY{n}{p44}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t45: \\PYZdq{}cuda:0 f32[32, 64]\\PYZdq{}}\n", - " \\PY{k}{del} \\PY{n}{p44}\n", - " \\PY{k}{return} \\PY{p}{(}\\PY{p}{(}\\PY{p}{\\PYZob{}}\\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{output}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{n}{t17}\\PY{p}{,} \\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{flat\\PYZus{}args}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{p}{[}\\PY{n}{t0}\\PY{p}{,} \\PY{n}{t12}\\PY{p}{,} \\PY{n}{t14}\\PY{p}{]}\\PY{p}{,} \\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{flat\\PYZus{}output}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{p}{(}\\PY{n}{t17}\\PY{p}{,}\\PY{p}{)}\\PY{p}{\\PYZcb{}}\\PY{p}{,} \\PY{p}{(}\\PY{p}{(}\\PY{n}{t0}\\PY{p}{,} \\PY{n}{t14}\\PY{p}{,} \\PY{n}{t16}\\PY{p}{)}\\PY{p}{,} \\PY{p}{(}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{,} \\PY{p}{(}\\PY{n}{t38}\\PY{p}{,} \\PY{n}{t45}\\PY{p}{,} \\PY{n}{t42}\\PY{p}{)}\\PY{p}{)}\n", - "\\end{Verbatim}\n" - ], - "text/plain": [ - "# Constructed by Delete Last Used (took 0 milliseconds)\n", - "import torch\n", - "import torch.nn.functional\n", - "from thunder.executors.torchex import no_autocast\n", - "\n", - "@torch.no_grad()\n", - "@no_autocast()\n", - "def _value_and_grad(*args):\n", - " # args: \"Collection\" \n", - " t0, \\\n", - " t1, \\\n", - " t2, \\\n", - " = args\n", - " del args\n", - " t3 = torch.full((64, 64), 1, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t3: \"cuda:0 f32[64, 64]\"\n", - " # t3 = ltorch.full((64, 64), 1, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t3: \"cuda:0 f32[64, 
64]\"\n", - " # t3 = prims.full((64, 64), 1, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t3: \"cuda:0 f32[64, 64]\"\n", - " t4 = torch.full((64, 64), 1, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t4: \"cuda:0 f32[64, 64]\"\n", - " # t4 = ltorch.full((64, 64), 1, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t4: \"cuda:0 f32[64, 64]\"\n", - " # t4 = prims.full((64, 64), 1, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t4: \"cuda:0 f32[64, 64]\"\n", - " t5 = torch.full((64, 64), 1, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t5: \"cuda:0 f32[64, 64]\"\n", - " # t5 = ltorch.full((64, 64), 1, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t5: \"cuda:0 f32[64, 64]\"\n", - " # t5 = prims.full((64, 64), 1, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t5: \"cuda:0 f32[64, 64]\"\n", - " t6 = torch.full((64, 64), 1, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t6: \"cuda:0 f32[64, 64]\"\n", - " # t6 = ltorch.full((64, 64), 1, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t6: \"cuda:0 f32[64, 64]\"\n", - " # t6 = prims.full((64, 64), 1, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t6: \"cuda:0 f32[64, 64]\"\n", - " t7 = torch.full((64, 64), 1, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t7: \"cuda:0 f32[64, 64]\"\n", - " # t7 = ltorch.full((64, 64), 1, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t7: \"cuda:0 f32[64, 64]\"\n", - " # t7 = prims.full((64, 64), 1, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t7: \"cuda:0 f32[64, 64]\"\n", - " t8 = torch.full((64, 64), 1, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t8: \"cuda:0 f32[64, 64]\"\n", - " # t8 = ltorch.full((64, 64), 1, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t8: \"cuda:0 f32[64, 64]\"\n", - " # t8 = prims.full((64, 64), 1, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t8: \"cuda:0 f32[64, 64]\"\n", - " t9 = torch.full((64, 64), 1, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t9: \"cuda:0 f32[64, 64]\"\n", - " # t9 = ltorch.full((64, 64), 1, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t9: \"cuda:0 f32[64, 64]\"\n", - " # t9 = prims.full((64, 64), 1, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t9: \"cuda:0 f32[64, 64]\"\n", - " t10 = torch.full((64, 64), 1, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t10: \"cuda:0 f32[64, 64]\"\n", - " # t10 = ltorch.full((64, 64), 1, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t10: \"cuda:0 f32[64, 64]\"\n", - " # t10 = prims.full((64, 64), 1, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t10: \"cuda:0 f32[64, 64]\"\n", - " p11 = torch_all_gather_prim_impl(t1, _torch_distributed_distributed_c10d_ProcessGroup_2, True) # p11: \"FUTURE cuda:0 f32[64, 64]\"\n", - " del t1\n", - " t12 = torch_wait_prim_impl(p11) # t12: \"cuda:0 f32[64, 64]\"\n", - " del p11\n", - " p13 = torch_all_gather_prim_impl(t2, _torch_distributed_distributed_c10d_ProcessGroup_2, True) # p13: \"FUTURE cuda:0 f32[64, 64]\"\n", - " del t2\n", - " t14 = torch_wait_prim_impl(p13) # t14: \"cuda:0 f32[64, 64]\"\n", - " del p13\n", - " t15 = torch.nn.functional.linear(t0, t12, None) # t15: \"cuda:0 f32[64, 64]\"\n", - " # t15 = ltorch.linear(t0, t12, None) # t15: \"cuda:0 f32[64, 64]\"\n", - " # t15 = prims.linear(t0, t12, None) # t15: \"cuda:0 f32[64, 64]\"\n", - " t16 = torch.tanh(t15) # t16: \"cuda:0 f32[64, 64]\"\n", - " # t16 = ltorch.tanh(t15) # t16: \"cuda:0 f32[64, 
64]\"\n", - " # t16 = prims.tanh(t15) # t16: \"cuda:0 f32[64, 64]\"\n", - " del t15\n", - " t17 = torch.nn.functional.linear(t16, t14, None) # t17: \"cuda:0 f32[64, 64]\"\n", - " # t17 = ltorch.linear(t16, t14, None) # t17: \"cuda:0 f32[64, 64]\"\n", - " # t17 = prims.linear(t16, t14, None) # t17: \"cuda:0 f32[64, 64]\"\n", - " t18 = torch.add(t6, t7) # t18: \"cuda:0 f32[64, 64]\"\n", - " # t18 = ltorch.add(t6, t7, alpha=None) # t18: \"cuda:0 f32[64, 64]\"\n", - " # t18 = prims.add(t6, t7) # t18: \"cuda:0 f32[64, 64]\"\n", - " del t6, t7\n", - " t19 = torch.add(t3, t8) # t19: \"cuda:0 f32[64, 64]\"\n", - " # t19 = ltorch.add(t3, t8, alpha=None) # t19: \"cuda:0 f32[64, 64]\"\n", - " # t19 = prims.add(t3, t8) # t19: \"cuda:0 f32[64, 64]\"\n", - " del t3, t8\n", - " t20 = torch.add(t5, t9) # t20: \"cuda:0 f32[64, 64]\"\n", - " # t20 = ltorch.add(t5, t9, alpha=None) # t20: \"cuda:0 f32[64, 64]\"\n", - " # t20 = prims.add(t5, t9) # t20: \"cuda:0 f32[64, 64]\"\n", - " del t5, t9\n", - " t21 = torch.reshape(t18, (-1, 64)) # t21: \"cuda:0 f32[64, 64]\"\n", - " # t21 = ltorch.reshape(t18, (-1, 64)) # t21: \"cuda:0 f32[64, 64]\"\n", - " # t21 = prims.reshape(t18, (64, 64)) # t21: \"cuda:0 f32[64, 64]\"\n", - " t22 = torch.matmul(t21, t14) # t22: \"cuda:0 f32[64, 64]\"\n", - " # t22 = ltorch.matmul(t21, t14) # t22: \"cuda:0 f32[64, 64]\"\n", - " # t22 = prims.matmul(t21, t14) # t22: \"cuda:0 f32[64, 64]\"\n", - " del t21\n", - " t23 = torch.reshape(t18, (-1, 64)) # t23: \"cuda:0 f32[64, 64]\"\n", - " # t23 = ltorch.reshape(t18, (-1, 64)) # t23: \"cuda:0 f32[64, 64]\"\n", - " # t23 = prims.reshape(t18, (64, 64)) # t23: \"cuda:0 f32[64, 64]\"\n", - " del t18\n", - " t24 = torch.permute(t23, (1, 0)) # t24: \"cuda:0 f32[64, 64]\"\n", - " # t24 = ltorch.permute(t23, (1, 0)) # t24: \"cuda:0 f32[64, 64]\"\n", - " # t24 = prims.transpose(t23, (1, 0)) # t24: \"cuda:0 f32[64, 64]\"\n", - " del t23\n", - " t25 = torch.reshape(t16, (-1, 64)) # t25: \"cuda:0 f32[64, 64]\"\n", - " # t25 = ltorch.reshape(t16, (-1, 64)) # t25: \"cuda:0 f32[64, 64]\"\n", - " # t25 = prims.reshape(t16, (64, 64)) # t25: \"cuda:0 f32[64, 64]\"\n", - " t26 = torch.matmul(t24, t25) # t26: \"cuda:0 f32[64, 64]\"\n", - " # t26 = ltorch.matmul(t24, t25) # t26: \"cuda:0 f32[64, 64]\"\n", - " # t26 = prims.matmul(t24, t25) # t26: \"cuda:0 f32[64, 64]\"\n", - " del t24, t25\n", - " t27 = torch.add(t10, t22) # t27: \"cuda:0 f32[64, 64]\"\n", - " # t27 = ltorch.add(t10, t22, alpha=None) # t27: \"cuda:0 f32[64, 64]\"\n", - " # t27 = prims.add(t10, t22) # t27: \"cuda:0 f32[64, 64]\"\n", - " del t10, t22\n", - " t28 = torch.add(t20, t26) # t28: \"cuda:0 f32[64, 64]\"\n", - " # t28 = ltorch.add(t20, t26, alpha=None) # t28: \"cuda:0 f32[64, 64]\"\n", - " # t28 = prims.add(t20, t26) # t28: \"cuda:0 f32[64, 64]\"\n", - " del t20, t26\n", - " t29 = torch.mul(t16, t16) # t29: \"cuda:0 f32[64, 64]\"\n", - " # t29 = ltorch.mul(t16, t16) # t29: \"cuda:0 f32[64, 64]\"\n", - " # t29 = prims.mul(t16, t16) # t29: \"cuda:0 f32[64, 64]\"\n", - " t30 = torch.sub(1, t29) # t30: \"cuda:0 f32[64, 64]\"\n", - " # t30 = ltorch.sub(1, t29, alpha=None) # t30: \"cuda:0 f32[64, 64]\"\n", - " # _ = prims.convert_element_type(1, float)\n", - " # t30 = prims.sub(1.0, t29) # t30: \"cuda:0 f32[64, 64]\"\n", - " del t29\n", - " t31 = torch.mul(t27, t30) # t31: \"cuda:0 f32[64, 64]\"\n", - " # t31 = ltorch.mul(t27, t30) # t31: \"cuda:0 f32[64, 64]\"\n", - " # t31 = prims.mul(t27, t30) # t31: \"cuda:0 f32[64, 64]\"\n", - " del t27, t30\n", - " t32 = torch.reshape(t31, (-1, 64)) # 
t32: \"cuda:0 f32[64, 64]\"\n", - " # t32 = ltorch.reshape(t31, (-1, 64)) # t32: \"cuda:0 f32[64, 64]\"\n", - " # t32 = prims.reshape(t31, (64, 64)) # t32: \"cuda:0 f32[64, 64]\"\n", - " t33 = torch.matmul(t32, t12) # t33: \"cuda:0 f32[64, 64]\"\n", - " # t33 = ltorch.matmul(t32, t12) # t33: \"cuda:0 f32[64, 64]\"\n", - " # t33 = prims.matmul(t32, t12) # t33: \"cuda:0 f32[64, 64]\"\n", - " del t32\n", - " t34 = torch.reshape(t31, (-1, 64)) # t34: \"cuda:0 f32[64, 64]\"\n", - " # t34 = ltorch.reshape(t31, (-1, 64)) # t34: \"cuda:0 f32[64, 64]\"\n", - " # t34 = prims.reshape(t31, (64, 64)) # t34: \"cuda:0 f32[64, 64]\"\n", - " del t31\n", - " t35 = torch.permute(t34, (1, 0)) # t35: \"cuda:0 f32[64, 64]\"\n", - " # t35 = ltorch.permute(t34, (1, 0)) # t35: \"cuda:0 f32[64, 64]\"\n", - " # t35 = prims.transpose(t34, (1, 0)) # t35: \"cuda:0 f32[64, 64]\"\n", - " del t34\n", - " t36 = torch.reshape(t0, (-1, 64)) # t36: \"cuda:0 f32[64, 64]\"\n", - " # t36 = ltorch.reshape(t0, (-1, 64)) # t36: \"cuda:0 f32[64, 64]\"\n", - " # t36 = prims.reshape(t0, (64, 64)) # t36: \"cuda:0 f32[64, 64]\"\n", - " t37 = torch.matmul(t35, t36) # t37: \"cuda:0 f32[64, 64]\"\n", - " # t37 = ltorch.matmul(t35, t36) # t37: \"cuda:0 f32[64, 64]\"\n", - " # t37 = prims.matmul(t35, t36) # t37: \"cuda:0 f32[64, 64]\"\n", - " del t35, t36\n", - " t38 = torch.add(t19, t33) # t38: \"cuda:0 f32[64, 64]\"\n", - " # t38 = ltorch.add(t19, t33, alpha=None) # t38: \"cuda:0 f32[64, 64]\"\n", - " # t38 = prims.add(t19, t33) # t38: \"cuda:0 f32[64, 64]\"\n", - " del t19, t33\n", - " t39 = torch.add(t4, t37) # t39: \"cuda:0 f32[64, 64]\"\n", - " # t39 = ltorch.add(t4, t37, alpha=None) # t39: \"cuda:0 f32[64, 64]\"\n", - " # t39 = prims.add(t4, t37) # t39: \"cuda:0 f32[64, 64]\"\n", - " del t4, t37\n", - " t40 = torch.true_divide(t28, 2) # t40: \"cuda:0 f32[64, 64]\"\n", - " # t40 = ltorch.true_divide(t28, 2) # t40: \"cuda:0 f32[64, 64]\"\n", - " # _ = prims.convert_element_type(2, float)\n", - " # t40 = prims.div(t28, 2.0) # t40: \"cuda:0 f32[64, 64]\"\n", - " del t28\n", - " p41 = torch_reduce_scatter_prim_impl(t40, _DistributedReduceOps_3, _torch_distributed_distributed_c10d_ProcessGroup_2, True) # p41: \"FUTURE cuda:0 f32[32, 64]\"\n", - " del t40\n", - " t42 = torch_wait_prim_impl(p41) # t42: \"cuda:0 f32[32, 64]\"\n", - " del p41\n", - " t43 = torch.true_divide(t39, 2) # t43: \"cuda:0 f32[64, 64]\"\n", - " # t43 = ltorch.true_divide(t39, 2) # t43: \"cuda:0 f32[64, 64]\"\n", - " # _ = prims.convert_element_type(2, float)\n", - " # t43 = prims.div(t39, 2.0) # t43: \"cuda:0 f32[64, 64]\"\n", - " del t39\n", - " p44 = torch_reduce_scatter_prim_impl(t43, _DistributedReduceOps_3, _torch_distributed_distributed_c10d_ProcessGroup_2, True) # p44: \"FUTURE cuda:0 f32[32, 64]\"\n", - " del t43\n", - " t45 = torch_wait_prim_impl(p44) # t45: \"cuda:0 f32[32, 64]\"\n", - " del p44\n", - " return (({'output': t17, 'flat_args': [t0, t12, t14], 'flat_output': (t17,)}, ((t0, t14, t16), ())), (t38, t45, t42))" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "optimized_trace = thunder.transform_for_execution(forward_backward_trace, executors_list=thunder.get_always_executors())\n", "\n", @@ -1749,17 +333,9 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Overwriting thunder_fsdp_simple_example.py\n" - ] - } - ], + "outputs": [], "source": [ 
"%%writefile thunder_fsdp_simple_example.py\n", "\n", @@ -1841,148 +417,9 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "W0314 08:26:39.130000 140292199276608 torch/distributed/run.py:757] \n", - "W0314 08:26:39.130000 140292199276608 torch/distributed/run.py:757] *****************************************\n", - "W0314 08:26:39.130000 140292199276608 torch/distributed/run.py:757] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. \n", - "W0314 08:26:39.130000 140292199276608 torch/distributed/run.py:757] *****************************************\n", - "# Constructed by Delete Last Used (took 0 milliseconds)\n", - "import torch\n", - "import torch.nn.functional\n", - "from thunder.executors.torchex import no_autocast\n", - "\n", - "@torch.no_grad()\n", - "@no_autocast()\n", - "def augmented_forward_fn(input, t_0_bias, t_2_bias, t_0_weight, t_2_weight):\n", - " # input: \"cuda:0 f32[64, 64]\" \n", - " # t_0_bias: \"cuda:0 f32[32]\" \n", - " p0 = torch_all_gather_prim_impl(t_0_bias, _torch_distributed_distributed_c10d_ProcessGroup_1, True) # p0: \"FUTURE cuda:0 f32[64]\"\n", - " # t_2_bias: \"cuda:0 f32[32]\" \n", - " p2 = torch_all_gather_prim_impl(t_2_bias, _torch_distributed_distributed_c10d_ProcessGroup_1, True) # p2: \"FUTURE cuda:0 f32[64]\"\n", - " # t_0_weight: \"cuda:0 f32[32, 64]\" \n", - " p4 = torch_all_gather_prim_impl(t_0_weight, _torch_distributed_distributed_c10d_ProcessGroup_1, True) # p4: \"FUTURE cuda:0 f32[64, 64]\"\n", - " # t_2_weight: \"cuda:0 f32[32, 64]\" \n", - " p9 = torch_all_gather_prim_impl(t_2_weight, _torch_distributed_distributed_c10d_ProcessGroup_1, True) # p9: \"FUTURE cuda:0 f32[64, 64]\"\n", - " t1 = torch_wait_prim_impl(p0) # t1: \"cuda:0 f32[64]\"\n", - " del p0\n", - " t3 = torch_wait_prim_impl(p2) # t3: \"cuda:0 f32[64]\"\n", - " del p2\n", - " t5 = torch_wait_prim_impl(p4) # t5: \"cuda:0 f32[64, 64]\"\n", - " del p4\n", - " t6 = torch.nn.functional.linear(input, t5, t1) # t6: \"cuda:0 f32[64, 64]\"\n", - " # t6 = ltorch.linear(input, t5, t1) # t6: \"cuda:0 f32[64, 64]\"\n", - " # t6 = prims.linear(input, t5, t1) # t6: \"cuda:0 f32[64, 64]\"\n", - " del t5, t1\n", - " [t7, t8] = nvFusion0(t6)\n", - " # t7 = prims.gt(t6, 0.0) # t7: \"cuda:0 b8[64, 64]\"\n", - " # t8 = prims.where(t7, t6, 0.0) # t8: \"cuda:0 f32[64, 64]\"\n", - " del t6\n", - " t10 = torch_wait_prim_impl(p9) # t10: \"cuda:0 f32[64, 64]\"\n", - " del p9\n", - " t11 = torch.nn.functional.linear(t8, t10, t3) # t11: \"cuda:0 f32[64, 64]\"\n", - " # t11 = ltorch.linear(t8, t10, t3) # t11: \"cuda:0 f32[64, 64]\"\n", - " # t11 = prims.linear(t8, t10, t3) # t11: \"cuda:0 f32[64, 64]\"\n", - " del t3\n", - " return {'output': t11, 'flat_args': [input, t_0_bias, t_2_bias, t_0_weight, t_2_weight], 'flat_output': (t11,)}, ((input, t10, t7, t8), ())\n", - "********************************************************\n", - "# Constructed by Delete Last Used (took 0 milliseconds)\n", - "import torch\n", - "from thunder.executors.torchex import no_autocast\n", - "\n", - "@torch.no_grad()\n", - "@no_autocast()\n", - "def backward_fn(saved_for_backward, cotangents):\n", - " # saved_for_backward: \"Collection\" \n", - " # cotangents: \"Collection\" \n", - " C0, \\\n", - " _, \\\n", - " = saved_for_backward\n", - " 
clear_collection(saved_for_backward)\n", - " del saved_for_backward\n", - " t0, \\\n", - " = cotangents\n", - " clear_collection(cotangents)\n", - " del cotangents\n", - " input, \\\n", - " t10, \\\n", - " t7, \\\n", - " t8, \\\n", - " = C0\n", - " clear_collection(C0)\n", - " del C0\n", - " t31 = torch.reshape(t0, (-1, 64)) # t31: \"cuda:0 f32[64, 64]\"\n", - " # t31 = ltorch.reshape(t0, (-1, 64)) # t31: \"cuda:0 f32[64, 64]\"\n", - " # t31 = prims.reshape(t0, (64, 64)) # t31: \"cuda:0 f32[64, 64]\"\n", - " t32 = torch.permute(t31, (1, 0)) # t32: \"cuda:0 f32[64, 64]\"\n", - " # t32 = ltorch.permute(t31, (1, 0)) # t32: \"cuda:0 f32[64, 64]\"\n", - " # t32 = prims.transpose(t31, (1, 0)) # t32: \"cuda:0 f32[64, 64]\"\n", - " t33 = torch.reshape(t8, (-1, 64)) # t33: \"cuda:0 f32[64, 64]\"\n", - " # t33 = ltorch.reshape(t8, (-1, 64)) # t33: \"cuda:0 f32[64, 64]\"\n", - " # t33 = prims.reshape(t8, (64, 64)) # t33: \"cuda:0 f32[64, 64]\"\n", - " del t8\n", - " t45 = torch.reshape(input, (-1, 64)) # t45: \"cuda:0 f32[64, 64]\"\n", - " # t45 = ltorch.reshape(input, (-1, 64)) # t45: \"cuda:0 f32[64, 64]\"\n", - " # t45 = prims.reshape(input, (64, 64)) # t45: \"cuda:0 f32[64, 64]\"\n", - " del input\n", - " [t51] = nvFusion0(t0)\n", - " # t35 = prims.sum(t0, (0,)) # t35: \"cuda:0 f32[64]\"\n", - " # t51 = prims.div(t35, 2.0) # t51: \"cuda:0 f32[64]\"\n", - " del t0\n", - " p52 = torch_reduce_scatter_prim_impl(t51, _DistributedReduceOps_0, _torch_distributed_distributed_c10d_ProcessGroup_1, True) # p52: \"FUTURE cuda:0 f32[32]\"\n", - " del t51\n", - " t30 = torch.matmul(t31, t10) # t30: \"cuda:0 f32[64, 64]\"\n", - " # t30 = ltorch.matmul(t29, t10) # t30: \"cuda:0 f32[64, 64]\"\n", - " # t30 = prims.matmul(t29, t10) # t30: \"cuda:0 f32[64, 64]\"\n", - " del t31, t10\n", - " t34 = torch.matmul(t32, t33) # t34: \"cuda:0 f32[64, 64]\"\n", - " # t34 = ltorch.matmul(t32, t33) # t34: \"cuda:0 f32[64, 64]\"\n", - " # t34 = prims.matmul(t32, t33) # t34: \"cuda:0 f32[64, 64]\"\n", - " del t32, t33\n", - " [t36, t39, t54] = nvFusion1(t30, t34, t7)\n", - " # t39 = prims.where(t7, t30, 0.0) # t39: \"cuda:0 f32[64, 64]\"\n", - " # t47 = prims.sum(t39, (0,)) # t47: \"cuda:0 f32[64]\"\n", - " # t54 = prims.div(t47, 2.0) # t54: \"cuda:0 f32[64]\"\n", - " # t36 = prims.div(t34, 2.0) # t36: \"cuda:0 f32[64, 64]\"\n", - " del t30, t34, t7\n", - " p37 = torch_reduce_scatter_prim_impl(t36, _DistributedReduceOps_0, _torch_distributed_distributed_c10d_ProcessGroup_1, True) # p37: \"FUTURE cuda:0 f32[32, 64]\"\n", - " del t36\n", - " p55 = torch_reduce_scatter_prim_impl(t54, _DistributedReduceOps_0, _torch_distributed_distributed_c10d_ProcessGroup_1, True) # p55: \"FUTURE cuda:0 f32[32]\"\n", - " del t54\n", - " t43 = torch.reshape(t39, (-1, 64)) # t43: \"cuda:0 f32[64, 64]\"\n", - " # t43 = ltorch.reshape(t39, (-1, 64)) # t43: \"cuda:0 f32[64, 64]\"\n", - " # t43 = prims.reshape(t39, (64, 64)) # t43: \"cuda:0 f32[64, 64]\"\n", - " del t39\n", - " t44 = torch.permute(t43, (1, 0)) # t44: \"cuda:0 f32[64, 64]\"\n", - " # t44 = ltorch.permute(t43, (1, 0)) # t44: \"cuda:0 f32[64, 64]\"\n", - " # t44 = prims.transpose(t43, (1, 0)) # t44: \"cuda:0 f32[64, 64]\"\n", - " del t43\n", - " t46 = torch.matmul(t44, t45) # t46: \"cuda:0 f32[64, 64]\"\n", - " # t46 = ltorch.matmul(t44, t45) # t46: \"cuda:0 f32[64, 64]\"\n", - " # t46 = prims.matmul(t44, t45) # t46: \"cuda:0 f32[64, 64]\"\n", - " del t44, t45\n", - " [t48] = nvFusion2(t46)\n", - " # t48 = prims.div(t46, 2.0) # t48: \"cuda:0 f32[64, 64]\"\n", - " del t46\n", - " p49 = 
torch_reduce_scatter_prim_impl(t48, _DistributedReduceOps_0, _torch_distributed_distributed_c10d_ProcessGroup_1, True) # p49: \"FUTURE cuda:0 f32[32, 64]\"\n", - " del t48\n", - " t53 = torch_wait_prim_impl(p52) # t53: \"cuda:0 f32[32]\"\n", - " del p52\n", - " t38 = torch_wait_prim_impl(p37) # t38: \"cuda:0 f32[32, 64]\"\n", - " del p37\n", - " t56 = torch_wait_prim_impl(p55) # t56: \"cuda:0 f32[32]\"\n", - " del p55\n", - " t50 = torch_wait_prim_impl(p49) # t50: \"cuda:0 f32[32, 64]\"\n", - " del p49\n", - " return (None, t56, t53, t50, t38)\n" - ] - } - ], + "outputs": [], "source": [ "!torchrun --nproc_per_node=2 thunder_fsdp_simple_example.py" ] @@ -2017,7 +454,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.7" + "version": "3.10.13" } }, "nbformat": 4, diff --git a/requirements/docs.txt b/requirements/docs.txt index 14615a08c4..69547efb64 100644 --- a/requirements/docs.txt +++ b/requirements/docs.txt @@ -7,8 +7,10 @@ docutils >=0.16 sphinxcontrib-fulltoc ==1.2.0 sphinxcontrib-mockautodoc -pt-lightning-sphinx-theme @ https://github.com/Lightning-AI/lightning_sphinx_theme/archive/master.zip sphinx-autodoc-typehints ==1.23.0 sphinx-paramlinks ==0.6.0 sphinx-togglebutton ==0.3.2 sphinx-copybutton ==0.5.2 + +# installed from S3 location and fetched in advance +lai-sphinx-theme diff --git a/thunder/__init__.py b/thunder/__init__.py index 21ba6499ec..3386590a1f 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -44,7 +44,7 @@ from thunder.core.compile_data import compile_data_and_stats from thunder.core.langctxs import LanguageContext import thunder.core.langctxs as langctxs -from thunder.core.baseutils import run_once +from thunder.core.baseutils import run_once, check from thunder.core.proxies import ( Proxy, TensorProxy, @@ -563,25 +563,38 @@ def get_computation_and_inputs(*args, **kwargs): # thunder_backward may recursively call compile and wraps the result in a # torch.autograd.Function to support embedding of Thunder-compiled # functions in torch's Autograd + + # Currently split_forward_backward also includes + # transform_for_execution and various sorting of symbols, + # applying transform_for_execution after this would be + # breaking the order of operations computation_trc, backward_trc = split_forward_backward(computation_trc, cd, cs, *inps) # Note computation_trc and backward_trc have been appended to cs.last_(backward_)traces # by split_forward_backward - - cs.last_computation_transformation_start = time.time_ns() - - ## EPILOGUE and TRANSFORMS should not mix... - # applies transforms - for transform in additional_transforms: - computation_trc = transform(computation_trc, executors_list=cd.executors_list) - computation_traces.append(computation_trc) - - with langctxs.langctx(cd.langctx): - extraces = transform_for_execution( - computation_trc, - executors_list=cd.executors_list, - ) - extrace = extraces[-1] - comp = extrace.python_callable() + extraces = cs.last_traces + check( + not additional_transforms, + lambda: "Specifying additional_transforms is not supported with PyTorch Autograd integration", + ) + + if backward_trc is None: + cs.last_computation_transformation_start = time.time_ns() + + ## EPILOGUE and TRANSFORMS should not mix... 
+ # applies transforms + for transform in additional_transforms: + computation_trc = transform(computation_trc, executors_list=cd.executors_list) + computation_traces.append(computation_trc) + + with langctxs.langctx(cd.langctx): + extraces = transform_for_execution( + computation_trc, + executors_list=cd.executors_list, + ) + computation_trc = extraces[-1] + cs.last_computation_transformation_stop = time.time_ns() + + comp = computation_trc.python_callable() if backward_trc is not None: backward_fn = backward_trc.python_callable() @@ -595,7 +608,6 @@ def get_computation_and_inputs(*args, **kwargs): if cd.cache_option is not CACHE_OPTIONS.NO_CACHING: cs.interpreter_cache.append(cache_entry) - cs.last_computation_transformation_stop = time.time_ns() cs.last_traces += extraces cs.last_prologue_traces = [prologue_trc] + protraces cs.last_prologue = pro diff --git a/thunder/benchmarks/distributed.py b/thunder/benchmarks/distributed.py index 3437c47e30..bc02528fe7 100644 --- a/thunder/benchmarks/distributed.py +++ b/thunder/benchmarks/distributed.py @@ -322,9 +322,11 @@ def parse_args() -> argparse.Namespace: ResultFormatter( model_name=args.model, base_name="torch_fsdp", - suffix=str(sharding_strategy).lower() + "-bucketing_" + "block" - if auto_wrap_policy is not None - else "none", + suffix=( + str(sharding_strategy).lower() + "-bucketing_" + "block" + if auto_wrap_policy is not None + else "none" + ), dtype=args.dtype, world_size=world_size, total_callable_construction_time=total_cct, @@ -352,9 +354,11 @@ def parse_args() -> argparse.Namespace: ResultFormatter( model_name=args.model, base_name="torch_compile_fsdp", - suffix=str(sharding_strategy).lower() + "-bucketing_" + "block" - if auto_wrap_policy is not None - else "none", + suffix=( + str(sharding_strategy).lower() + "-bucketing_" + "block" + if auto_wrap_policy is not None + else "none" + ), dtype=args.dtype, world_size=world_size, total_callable_construction_time=total_cct, diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py index 6c0a6ffc0c..bbba0a08cb 100644 --- a/thunder/core/jit_ext.py +++ b/thunder/core/jit_ext.py @@ -377,7 +377,9 @@ class ThunderSharpEdgeError(RuntimeError): def _sharp_edge(desc: str, value: Any, /) -> Any | INTERPRETER_SIGNALS: sharp_edges: SHARP_EDGES_OPTIONS = get_minimal_ctx().sharp_edges - s: str = f"{desc} is a sharp edge that cannot be translated to a thunder program unless using interpretation=INTERPRETATION_OPTIONS.TRANSLATE_PYTHON." + s: str = ( + f"{desc} is a sharp edge that cannot be translated to a thunder program unless using interpretation=INTERPRETATION_OPTIONS.TRANSLATE_PYTHON." + ) if sharp_edges is SHARP_EDGES_OPTIONS.ERROR: return do_raise(ThunderSharpEdgeError(s)) @@ -469,7 +471,9 @@ class JITSharpEdgeError(RuntimeError): def _general_jit_sharp_edge(desc: str, value: Any, /) -> Any | INTERPRETER_SIGNALS: sharp_edges: SHARP_EDGES_OPTIONS = get_minimal_ctx().sharp_edges - s: str = f"{desc} This is currently considered a sharp edge even with interpretation=INTERPRETATION_OPTIONS.TRANSLATE_PYTHON. For cases in which we are overly strict, please file an issue. Thank you!" + s: str = ( + f"{desc} This is currently considered a sharp edge even with interpretation=INTERPRETATION_OPTIONS.TRANSLATE_PYTHON. For cases in which we are overly strict, please file an issue. Thank you!" 
+ ) if sharp_edges is SHARP_EDGES_OPTIONS.ERROR: return do_raise(JITSharpEdgeError(s)) diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py index 772e65a84d..3bd25c3297 100644 --- a/thunder/core/transforms.py +++ b/thunder/core/transforms.py @@ -3802,7 +3802,7 @@ def augmented_forward_fn(*args, **kwargs): # Copy the signature of the original function so that the arguments are # named correctly in the augmented forward pass instead of being named # "args" and "kwargs". - augmented_forward_fn.__signature__ = inspect.signature(trace.fn) + augmented_forward_fn.__signature__ = inspect.signature(trace.fn or trace.python_callable()) def ones_like(x): if isinstance(x, TensorProxy): diff --git a/thunder/core/utils.py b/thunder/core/utils.py index 1edebe70e4..f0d4764fc4 100644 --- a/thunder/core/utils.py +++ b/thunder/core/utils.py @@ -743,16 +743,13 @@ class FrozenDict(_UserDictT[T, T1], Mapping[T, T1]): """ @overload - def __init__(self, data: Mapping[T, T1]) -> None: - ... + def __init__(self, data: Mapping[T, T1]) -> None: ... @overload - def __init__(self, data: Iterable[T, T1]) -> None: - ... + def __init__(self, data: Iterable[T, T1]) -> None: ... @overload - def __init__(self, *args: Any, **kwargs: Any) -> None: - ... + def __init__(self, *args: Any, **kwargs: Any) -> None: ... def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) @@ -834,8 +831,7 @@ def _safe_zip_gen(*args): @overload -def safe_zip(x: Iterable[T], y: Iterable[T1], /) -> Iterable[tuple[T, T1]]: - ... +def safe_zip(x: Iterable[T], y: Iterable[T1], /) -> Iterable[tuple[T, T1]]: ... def safe_zip(*args): diff --git a/thunder/core/vjp_utils.py b/thunder/core/vjp_utils.py index 3c8128ffeb..8910dc3110 100644 --- a/thunder/core/vjp_utils.py +++ b/thunder/core/vjp_utils.py @@ -1,11 +1,12 @@ import inspect +from collections.abc import Callable +from functools import wraps from inspect import Parameter, Signature from itertools import chain -from collections.abc import Callable from thunder.core import prims, utils from thunder.core.prims import PrimIDs -from thunder.core.proxies import variableify, Proxy +from thunder.core.proxies import Proxy, variableify from thunder.core.pytree import tree_flatten, tree_map from thunder.core.symbol import BoundSymbol from thunder.core.trace import from_trace, TraceCtx @@ -15,6 +16,16 @@ _cache = {} +def disable_caching_split_forward_and_backward(fn): + @wraps(fn) + def wrapper(*args, **kwargs): + return fn(*args, **kwargs) + + wrapper._disable_caching = True + + return wrapper + + def make_aug_forward_and_backward(bsym: BoundSymbol) -> tuple[Callable, Callable]: """ Given a bound symbol, return a pair of forward and backward functions @@ -46,7 +57,7 @@ def make_aug_forward_and_backward(bsym: BoundSymbol) -> tuple[Callable, Callable key = (bsym.sym, subkey := _make_cache_key(bsym.args, bsym.kwargs)) cached_result = _cache.get(key, None) if subkey is not None else None - if cached_result is not None: + if cached_result is not None and not getattr(joint_forward_backward, "_disable_caching", False): return cached_result joint_trace = thunder.trace(inline_trace=False, use_dce=False)(joint_forward_backward, *bsym.args, **bsym.kwargs) diff --git a/thunder/distributed/utils.py b/thunder/distributed/utils.py index 9d9fa7bf42..48cc66d4c8 100644 --- a/thunder/distributed/utils.py +++ b/thunder/distributed/utils.py @@ -84,12 +84,12 @@ def prefer_comm_over_other_over_wait_over_allgather(eligible_nodes: list[Node]) # nodes over "wait_prim_impl", pick 
"all_gather_prim_impl" last. def key(node: Node) -> int: match node.bsym.sym.id: - case (wait_prim_impl.id | unpack_for_fsdp_prim_impl.id): + case wait_prim_impl.id | unpack_for_fsdp_prim_impl.id: return len(order_in_trace) - case (reduce_scatter_prim_impl.id | all_reduce_prim_impl.id): + case reduce_scatter_prim_impl.id | all_reduce_prim_impl.id: # Prefer larger communication ops over smaller ones return -node.bsym.args[0].numel - case (all_gather_prim_impl.id): + case all_gather_prim_impl.id: return len(order_in_trace) + order_in_trace[node.bsym] case _: # Prefer nodes that are earlier in the trace @@ -141,9 +141,9 @@ def prefer_comm_over_other_over_wait(eligible_nodes: list[Node]) -> int: # nodes over "wait_prim_impl" def key(node: Node) -> int: match node.bsym.sym.id: - case (wait_prim_impl.id): + case wait_prim_impl.id: return len(order_in_trace) - case (reduce_scatter_prim_impl.id | all_reduce_prim_impl.id | all_gather_prim_impl.id): + case reduce_scatter_prim_impl.id | all_reduce_prim_impl.id | all_gather_prim_impl.id: # Prefer larger communication ops over smaller ones return -node.bsym.args[0].numel case _: diff --git a/thunder/executors/torch_autograd.py b/thunder/executors/torch_autograd.py index 93c0c2d874..f256ad338b 100644 --- a/thunder/executors/torch_autograd.py +++ b/thunder/executors/torch_autograd.py @@ -202,7 +202,7 @@ def make_trace(func): if not any(requires_grad_mask): raise RuntimeError("PyTorch's Autograd interface requires at least one tensor input with requires_grad=True") - primal_trace = make_trace(func)(*args, **kwargs) + primal_trace = make_trace(func)(*args, **kwargs) if not compile_data.using_jit else computation_trc primal_trace = sort_data_parallel_syncs(primal_trace) if compile_stats is not None: diff --git a/thunder/executors/transformer_engineex.py b/thunder/executors/transformer_engineex.py index ba53c3e940..c1842f11ca 100644 --- a/thunder/executors/transformer_engineex.py +++ b/thunder/executors/transformer_engineex.py @@ -19,6 +19,7 @@ import thunder.core.prims as prims from thunder.core.proxies import TensorProxy, CollectionProxy from thunder.core.symbol import Symbol +from thunder.core.vjp_utils import disable_caching_split_forward_and_backward from thunder.extend import OperatorExecutor, register_executor from thunder.core.langctxs import langctx, Languages @@ -412,6 +413,7 @@ def _linear_transform(a: TensorProxy, w: TensorProxy, b: TensorProxy) -> torch.T return _create_fp8_linear_bound_symbol(a, w, b, is_grad_enabled=False) +@disable_caching_split_forward_and_backward def _linear_grad(a: TensorProxy, w: TensorProxy, b: TensorProxy) -> TensorProxy: out, saved_for_backward = linear_forwad_rule(a, w, b) g = prims.get_grad(out) diff --git a/thunder/tests/litgpt_model.py b/thunder/tests/litgpt_model.py index 23b51545a0..13ab52f44b 100644 --- a/thunder/tests/litgpt_model.py +++ b/thunder/tests/litgpt_model.py @@ -1,4 +1,5 @@ """Taken from https://github.com/Lightning-AI/litgpt/blob/main/litgpt/model.py""" + import torch import torch.nn as nn diff --git a/thunder/tests/test_core.py b/thunder/tests/test_core.py index 54b27f4964..377a284a6c 100644 --- a/thunder/tests/test_core.py +++ b/thunder/tests/test_core.py @@ -2265,6 +2265,34 @@ def func(x, y, *, z): assert bsym.flat_args == [1, 2, 3] +@instantiate(dtypes=NOTHING) +def test_preserve_weight_names(executor, device: str, dtype: dtypes.dtype): + import inspect + + class MLP(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc1 = torch.nn.Linear(3, 4) + self.fc2 = torch.nn.Linear(4, 
5) + + def forward(self, x): + x = self.fc1(x) + x = self.fc2(x) + return x + + model = MLP().to(device=device, dtype=ltorch.to_torch_dtype(dtype)) + x = torch.randn(2, 3, device=device, dtype=ltorch.to_torch_dtype(dtype)) + + compiled = thunder.jit(model, executors=executor.executors_list()) + compiled(x) + traces = thunder.last_traces(compiled) + sig = inspect.signature(traces[-1].python_callable()) + assert "t_fc1_bias" in sig.parameters + assert "t_fc1_weight" in sig.parameters + assert "t_fc2_bias" in sig.parameters + assert "t_fc2_weight" in sig.parameters + + # @instantiate( # dtypes=NOTHING, # ) diff --git a/thunder/tests/test_transformer_engine_executor.py b/thunder/tests/test_transformer_engine_executor.py index 41b6f83e36..e8abcc7c09 100644 --- a/thunder/tests/test_transformer_engine_executor.py +++ b/thunder/tests/test_transformer_engine_executor.py @@ -33,7 +33,7 @@ def test_te_linear_forward_backward(): # TE inputs (3D input) x_te = torch.randn(3, 768, 4096, device=device, dtype=dtype, requires_grad=True) te_linear1 = te.Linear(4096, 4096, params_dtype=dtype) - te_linear2 = te.Linear(4096, 2048, params_dtype=dtype) + te_linear2 = te.Linear(4096, 4096, params_dtype=dtype) # thunder inputs x = x_te.detach().clone()
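
The new disable_caching_split_forward_and_backward decorator in thunder/core/vjp_utils.py works by tagging a grad-transform function with a `_disable_caching` attribute, which make_aug_forward_and_backward then checks before serving a cached forward/backward split; the freshly built result is still written to the cache, only the lookup is bypassed. Below is a minimal standalone sketch of that pattern, with hypothetical names (`disable_caching`, `split_forward_backward_cached`, `my_grad_rule`) standing in for the Thunder internals; it is an illustration of the opt-out mechanism, not the actual implementation.

    from collections.abc import Callable
    from functools import wraps

    _cache: dict = {}


    def disable_caching(fn: Callable) -> Callable:
        # Tag the wrapped function so cache lookups can be skipped for it,
        # mirroring disable_caching_split_forward_and_backward above.
        @wraps(fn)
        def wrapper(*args, **kwargs):
            return fn(*args, **kwargs)

        wrapper._disable_caching = True
        return wrapper


    def split_forward_backward_cached(grad_rule: Callable, *args):
        # Hypothetical stand-in for make_aug_forward_and_backward: build a
        # (forward, backward) pair for `grad_rule` and memoize it, unless the
        # rule has opted out of cache reads via the decorator above.
        key = (grad_rule, args)
        cached = _cache.get(key)
        if cached is not None and not getattr(grad_rule, "_disable_caching", False):
            return cached
        result = (f"aug_fwd{args}", f"bwd{args}")  # placeholder for real traces
        _cache[key] = result  # still populated; only the read path is skipped
        return result


    @disable_caching
    def my_grad_rule(x):
        # Stands in for a rule such as TransformerEngine's _linear_grad,
        # which the patch marks to always be re-traced.
        return x


    if __name__ == "__main__":
        print(split_forward_backward_cached(my_grad_rule, 1.0))
        print(split_forward_backward_cached(my_grad_rule, 1.0))  # rebuilt, not served from _cache

In this sketch the second call recomputes the split even though an entry exists in `_cache`, which is the behavior the patch gives to executors whose grad rules should not be reused across calls.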