diff --git a/.github/workflows/docs-build.yml b/.github/workflows/docs-build.yml
index b8d87e466d..4dd1365477 100644
--- a/.github/workflows/docs-build.yml
+++ b/.github/workflows/docs-build.yml
@@ -15,11 +15,51 @@ defaults:
jobs:
docs-make:
- uses: Lightning-AI/utilities/.github/workflows/check-docs.yml@v0.11.0
- with:
- python-version: "3.10"
- requirements-file: "requirements/docs.txt"
- install-tex: true
+ if: github.event.pull_request.draft == false
+ runs-on: ubuntu-22.04
+ strategy:
+ fail-fast: false
+ matrix:
+ target: ["html", "doctest", "linkcheck"]
+ env:
+ ARTIFACT_DAYS: 0
+ PYPI_LOCAL_DIR: "pypi_pkgs/"
+ steps:
+ - uses: actions/checkout@v4
+ - uses: actions/setup-python@v5
+ with:
+ python-version: "3.10"
+
+ - name: Pull sphinx template
+ run: |
+ pip install -q "awscli >=1.30.0"
+ aws s3 sync --no-sign-request s3://sphinx-packages/ ${PYPI_LOCAL_DIR}
+ pip install lai-sphinx-theme -U -f ${PYPI_LOCAL_DIR}
+ - name: Install pandoc
+ timeout-minutes: 5
+ run: |
+ sudo apt-get update --fix-missing
+ sudo apt-get install -y pandoc
+ - name: Install package & dependencies
+ timeout-minutes: 20
+ run: pip install . -U -r requirements/docs.txt
+
+ - name: Make ${{ matrix.target }}
+ working-directory: docs/
+ # allow the doctest and linkcheck targets to fail without failing the workflow (e.g. when run via dispatch)
+ continue-on-error: ${{ matrix.target == 'doctest' || matrix.target == 'linkcheck' }}
+ run: make ${{ matrix.target }} --debug --jobs $(nproc) SPHINXOPTS="-W --keep-going"
+
+ - name: Keep artifact
+ if: github.event_name == 'pull_request'
+ run: echo "ARTIFACT_DAYS=7" >> $GITHUB_ENV
+ - name: Upload built docs
+ if: ${{ matrix.target == 'html' }}
+ uses: actions/upload-artifact@v4
+ with:
+ name: docs-html-${{ github.sha }}
+ path: docs/build/html/
+ retention-days: ${{ env.ARTIFACT_DAYS }}
deploy-docs:
needs: docs-make
@@ -28,7 +68,7 @@ jobs:
env:
GCP_TARGET: "gs://lightning-docs-thunder"
steps:
- - uses: actions/download-artifact@v3
+ - uses: actions/download-artifact@v4
with:
name: docs-html-${{ github.sha }}
path: docs/build/
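Note: `actions/download-artifact` moves to v4 here to stay in lockstep with the `actions/upload-artifact@v4` step added above; artifacts uploaded by the v4 action cannot be downloaded by the v3 action, so both sides of the pair have to be upgraded together.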
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index f16e1ef98f..b3e1716d00 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -8,7 +8,7 @@ ci:
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
- rev: v4.4.0
+ rev: v4.5.0
hooks:
- id: end-of-file-fixer
- id: trailing-whitespace
@@ -24,7 +24,7 @@ repos:
- id: detect-private-key
- repo: https://github.com/asottile/pyupgrade
- rev: v3.11.1
+ rev: v3.15.2
hooks:
- id: pyupgrade
args: ["--py310-plus"]
@@ -32,14 +32,14 @@ repos:
exclude: "examples|thunder/tests/test_interpreter.py|thunder/tests/test_jit_general.py"
- repo: https://github.com/codespell-project/codespell
- rev: v2.2.5
+ rev: v2.2.6
hooks:
- id: codespell
additional_dependencies: [tomli]
#args: ["--write-changes"] # uncomment if you want to get automatic fixing
- repo: https://github.com/psf/black
- rev: 23.9.1
+ rev: 24.3.0
hooks:
- id: black
name: Black code
@@ -61,7 +61,7 @@ repos:
- id: sphinx-lint
- repo: https://github.com/asottile/yesqa
- rev: v1.4.0
+ rev: v1.5.0
hooks:
- id: yesqa
diff --git a/Makefile b/Makefile
index 677ca6b1b9..2c27c12f5c 100644
--- a/Makefile
+++ b/Makefile
@@ -10,7 +10,13 @@ test: clean
python -m coverage run --source thunder -m pytest thunder tests -v
python -m coverage report
-docs: clean
+sphinx-theme:
+ pip install -q awscli
+ mkdir -p dist/
+ aws s3 sync --no-sign-request s3://sphinx-packages/ dist/
+ pip install lai-sphinx-theme -f dist/
+
+docs: clean sphinx-theme
pip install -e . --quiet -r requirements/docs.txt -f https://download.pytorch.org/whl/cpu/torch_stable.html
cd docs ; python -m sphinx -b html -W --keep-going source build
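With these targets, a local docs build is simply `make docs`: the new `sphinx-theme` prerequisite fetches the `lai-sphinx-theme` wheel from the public S3 bucket first (no AWS credentials are needed, since the sync runs with `--no-sign-request`), and Sphinx then builds the HTML as before.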
diff --git a/README.md b/README.md
index a672abe48e..3e33bf5bd4 100644
--- a/README.md
+++ b/README.md
@@ -14,8 +14,9 @@ ______________________________________________________________________
Get started •
Install •
Examples •
- Features •
- Documentation •
+ Inside Thunder •
+ Get involved! •
+ Documentation
[![license](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/Lightning-AI/lightning-thunder/blob/main/LICENSE)
@@ -30,41 +31,58 @@ ______________________________________________________________________
**Thunder makes PyTorch models Lightning fast.**
-Thunder is a source-to-source compiler for PyTorch. It makes PyTorch programs faster by combining and using different hardware executors at once (ie: nvFuser, torch.compile, cuDNN, and TransformerEngine FP8).
+Thunder is a source-to-source compiler for PyTorch. It makes PyTorch programs faster by combining and using different hardware executors at once (for instance, [nvFuser](https://github.com/NVIDIA/Fuser), [torch.compile](https://pytorch.org/docs/stable/torch.compiler.html), [cuDNN](https://developer.nvidia.com/cudnn), and [TransformerEngine FP8](https://github.com/NVIDIA/TransformerEngine)).
-Works on single accelerators and in multi-GPU settings.
+It supports both single and multi-GPU configurations.
Thunder aims to be usable, understandable, and extensible.
-## Performance
+
-Thunder can achieve significant speedups over standard PyTorch eager code, through the compounding effects of optimizations and the use of best-in-class executors. Here is an example of the pretraining throughput for Llama 2 7B as implemented in [LitGPT](https://github.com/Lightning-AI/litgpt).
+> [!NOTE]
+> Lightning Thunder is in alpha. Feel free to get involved, but expect a few bumps along the way.
+
+
+
+## Single-GPU performance
+
+Thunder can achieve significant speedups over standard non-compiled PyTorch code ("PyTorch eager"), through the compounding effects of optimizations and the use of best-in-class executors. The figure below shows the pretraining throughput for Llama 2 7B as implemented in [LitGPT](https://github.com/Lightning-AI/litgpt).
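To make this concrete, here is a minimal sketch of compiling a function with Thunder (assuming the `thunder.jit` entry point; the toy function and shapes are illustrative):

```python
import torch
import thunder

def foo(a, b):
    return torch.tanh(a) + b

# thunder.jit returns a compiled callable; the first call traces the
# program and hands it to the configured executors (illustrative usage).
jfoo = thunder.jit(foo)

a = torch.randn(2, 2)
b = torch.randn(2, 2)
print(jfoo(a, b))
```

The compiled callable is meant to behave as a drop-in replacement for the original function.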
    -     "# Constructed by Dead Code Elimination (took 1 milliseconds)\n",
- "import thunder\n",
- "import thunder.core.devices as devices\n",
- "import thunder.core.dtypes as dtypes\n",
- "import thunder.core.prims as prims\n",
- "import thunder.distributed.prims\n",
- "import thunder.torch as ltorch\n",
- "import torch\n",
- "from thunder.executors.torchex import no_autocast\n",
- "\n",
- "@torch.no_grad()\n",
- "@no_autocast()\n",
- "def _value_and_grad(*args):\n",
- " # args: "Collection" \n",
- " t0, \\\n",
- " t1, \\\n",
- " t2, \\\n",
- " = args\n",
- " t3 = prims.full((64, 64), 1, device=devices.Device("cuda:0"), dtype=dtypes.float32) # t3: "cuda:0 f32[64, 64]"\n",
- " t4 = prims.full((64, 64), 1, device=devices.Device("cuda:0"), dtype=dtypes.float32) # t4: "cuda:0 f32[64, 64]"\n",
- " t5 = prims.full((64, 64), 1, device=devices.Device("cuda:0"), dtype=dtypes.float32) # t5: "cuda:0 f32[64, 64]"\n",
- " t6 = prims.full((64, 64), 1, device=devices.Device("cuda:0"), dtype=dtypes.float32) # t6: "cuda:0 f32[64, 64]"\n",
- " t7 = prims.full((64, 64), 1, device=devices.Device("cuda:0"), dtype=dtypes.float32) # t7: "cuda:0 f32[64, 64]"\n",
- " t8 = prims.full((64, 64), 1, device=devices.Device("cuda:0"), dtype=dtypes.float32) # t8: "cuda:0 f32[64, 64]"\n",
- " t9 = prims.full((64, 64), 1, device=devices.Device("cuda:0"), dtype=dtypes.float32) # t9: "cuda:0 f32[64, 64]"\n",
- " t10 = prims.full((64, 64), 1, device=devices.Device("cuda:0"), dtype=dtypes.float32) # t10: "cuda:0 f32[64, 64]"\n",
- " p11 = thunder.distributed.prims.all_gather(t1, _torch_distributed_distributed_c10d_ProcessGroup_0, True) # p11: "FUTURE cuda:0 f32[64, 64]"\n",
- " t12 = thunder.distributed.prims.wait(p11) # t12: "cuda:0 f32[64, 64]"\n",
- " p13 = thunder.distributed.prims.all_gather(t2, _torch_distributed_distributed_c10d_ProcessGroup_0, True) # p13: "FUTURE cuda:0 f32[64, 64]"\n",
- " t14 = thunder.distributed.prims.wait(p13) # t14: "cuda:0 f32[64, 64]"\n",
- " t15 = prims.linear(t0, t12, None) # t15: "cuda:0 f32[64, 64]"\n",
- " t16 = prims.tanh(t15) # t16: "cuda:0 f32[64, 64]"\n",
- " t17 = prims.linear(t16, t14, None) # t17: "cuda:0 f32[64, 64]"\n",
- " t18 = prims.add(t6, t7) # t18: "cuda:0 f32[64, 64]"\n",
- " t19 = prims.add(t3, t8) # t19: "cuda:0 f32[64, 64]"\n",
- " t20 = prims.add(t5, t9) # t20: "cuda:0 f32[64, 64]"\n",
- " t21 = ltorch.reshape(t18, -1, 64) # t21: "cuda:0 f32[64, 64]"\n",
- " # t21 = prims.reshape(t18, (64, 64)) # t21: "cuda:0 f32[64, 64]"\n",
- " t22 = ltorch.matmul(t21, t14) # t22: "cuda:0 f32[64, 64]"\n",
- " # t22 = prims.matmul(t21, t14) # t22: "cuda:0 f32[64, 64]"\n",
- " t23 = ltorch.reshape(t18, -1, 64) # t23: "cuda:0 f32[64, 64]"\n",
- " # t23 = prims.reshape(t18, (64, 64)) # t23: "cuda:0 f32[64, 64]"\n",
- " t24 = prims.transpose(t23, (1, 0)) # t24: "cuda:0 f32[64, 64]"\n",
- " t25 = ltorch.reshape(t16, -1, 64) # t25: "cuda:0 f32[64, 64]"\n",
- " # t25 = prims.reshape(t16, (64, 64)) # t25: "cuda:0 f32[64, 64]"\n",
- " t26 = ltorch.matmul(t24, t25) # t26: "cuda:0 f32[64, 64]"\n",
- " # t26 = prims.matmul(t24, t25) # t26: "cuda:0 f32[64, 64]"\n",
- " t27 = prims.add(t10, t22) # t27: "cuda:0 f32[64, 64]"\n",
- " t28 = prims.add(t20, t26) # t28: "cuda:0 f32[64, 64]"\n",
- " t29 = ltorch.mul(t16, t16) # t29: "cuda:0 f32[64, 64]"\n",
- " # t29 = prims.mul(t16, t16) # t29: "cuda:0 f32[64, 64]"\n",
- " t30 = ltorch.sub(1, t29, alpha=None) # t30: "cuda:0 f32[64, 64]"\n",
- " # _ = prims.convert_element_type(1, float)\n",
- " # t30 = prims.sub(1.0, t29) # t30: "cuda:0 f32[64, 64]"\n",
- " t31 = ltorch.mul(t27, t30) # t31: "cuda:0 f32[64, 64]"\n",
- " # t31 = prims.mul(t27, t30) # t31: "cuda:0 f32[64, 64]"\n",
- " t32 = ltorch.reshape(t31, -1, 64) # t32: "cuda:0 f32[64, 64]"\n",
- " # t32 = prims.reshape(t31, (64, 64)) # t32: "cuda:0 f32[64, 64]"\n",
- " t33 = ltorch.matmul(t32, t12) # t33: "cuda:0 f32[64, 64]"\n",
- " # t33 = prims.matmul(t32, t12) # t33: "cuda:0 f32[64, 64]"\n",
- " t34 = ltorch.reshape(t31, -1, 64) # t34: "cuda:0 f32[64, 64]"\n",
- " # t34 = prims.reshape(t31, (64, 64)) # t34: "cuda:0 f32[64, 64]"\n",
- " t35 = prims.transpose(t34, (1, 0)) # t35: "cuda:0 f32[64, 64]"\n",
- " t36 = ltorch.reshape(t0, -1, 64) # t36: "cuda:0 f32[64, 64]"\n",
- " # t36 = prims.reshape(t0, (64, 64)) # t36: "cuda:0 f32[64, 64]"\n",
- " t37 = ltorch.matmul(t35, t36) # t37: "cuda:0 f32[64, 64]"\n",
- " # t37 = prims.matmul(t35, t36) # t37: "cuda:0 f32[64, 64]"\n",
- " t38 = prims.add(t19, t33) # t38: "cuda:0 f32[64, 64]"\n",
- " t39 = prims.add(t4, t37) # t39: "cuda:0 f32[64, 64]"\n",
- " t40 = ltorch.true_divide(t28, 2) # t40: "cuda:0 f32[64, 64]"\n",
- " # _ = prims.convert_element_type(2, float)\n",
- " # t40 = prims.div(t28, 2.0) # t40: "cuda:0 f32[64, 64]"\n",
- " p41 = thunder.distributed.prims.reduce_scatter(t40, _DistributedReduceOps_1, _torch_distributed_distributed_c10d_ProcessGroup_0, True) # p41: "FUTURE cuda:0 f32[32, 64]"\n",
- " t42 = thunder.distributed.prims.wait(p41) # t42: "cuda:0 f32[32, 64]"\n",
- " t43 = ltorch.true_divide(t39, 2) # t43: "cuda:0 f32[64, 64]"\n",
- " # _ = prims.convert_element_type(2, float)\n",
- " # t43 = prims.div(t39, 2.0) # t43: "cuda:0 f32[64, 64]"\n",
- " p44 = thunder.distributed.prims.reduce_scatter(t43, _DistributedReduceOps_1, _torch_distributed_distributed_c10d_ProcessGroup_0, True) # p44: "FUTURE cuda:0 f32[32, 64]"\n",
- " t45 = thunder.distributed.prims.wait(p44) # t45: "cuda:0 f32[32, 64]"\n",
- " return (({'output': t17, 'flat_args': [t0, t12, t14], 'flat_output': (t17,)}, ((t0, t14, t16), ())), (t38, t45, t42))\n",
- "
\n"
- ],
- "text/latex": [
- "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n",
- "\\PY{c+c1}{\\PYZsh{} Constructed by Dead Code Elimination (took 1 milliseconds)}\n",
- "\\PY{k+kn}{import} \\PY{n+nn}{thunder}\n",
- "\\PY{k+kn}{import} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{core}\\PY{n+nn}{.}\\PY{n+nn}{devices} \\PY{k}{as} \\PY{n+nn}{devices}\n",
- "\\PY{k+kn}{import} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{core}\\PY{n+nn}{.}\\PY{n+nn}{dtypes} \\PY{k}{as} \\PY{n+nn}{dtypes}\n",
- "\\PY{k+kn}{import} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{core}\\PY{n+nn}{.}\\PY{n+nn}{prims} \\PY{k}{as} \\PY{n+nn}{prims}\n",
- "\\PY{k+kn}{import} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{distributed}\\PY{n+nn}{.}\\PY{n+nn}{prims}\n",
- "\\PY{k+kn}{import} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{torch} \\PY{k}{as} \\PY{n+nn}{ltorch}\n",
- "\\PY{k+kn}{import} \\PY{n+nn}{torch}\n",
- "\\PY{k+kn}{from} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{executors}\\PY{n+nn}{.}\\PY{n+nn}{torchex} \\PY{k+kn}{import} \\PY{n}{no\\PYZus{}autocast}\n",
- "\n",
- "\\PY{n+nd}{@torch}\\PY{o}{.}\\PY{n}{no\\PYZus{}grad}\\PY{p}{(}\\PY{p}{)}\n",
- "\\PY{n+nd}{@no\\PYZus{}autocast}\\PY{p}{(}\\PY{p}{)}\n",
- "\\PY{k}{def} \\PY{n+nf}{\\PYZus{}value\\PYZus{}and\\PYZus{}grad}\\PY{p}{(}\\PY{o}{*}\\PY{n}{args}\\PY{p}{)}\\PY{p}{:}\n",
- " \\PY{c+c1}{\\PYZsh{} args: \\PYZdq{}Collection\\PYZdq{} }\n",
- " \\PY{n}{t0}\\PY{p}{,} \\PYZbs{}\n",
- " \\PY{n}{t1}\\PY{p}{,} \\PYZbs{}\n",
- " \\PY{n}{t2}\\PY{p}{,} \\PYZbs{}\n",
- " \\PY{o}{=} \\PY{n}{args}\n",
- " \\PY{n}{t3} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{full}\\PY{p}{(}\\PY{p}{(}\\PY{l+m+mi}{64}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{device}\\PY{o}{=}\\PY{n}{devices}\\PY{o}{.}\\PY{n}{Device}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{cuda:0}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{dtype}\\PY{o}{=}\\PY{n}{dtypes}\\PY{o}{.}\\PY{n}{float32}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t3: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{n}{t4} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{full}\\PY{p}{(}\\PY{p}{(}\\PY{l+m+mi}{64}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{device}\\PY{o}{=}\\PY{n}{devices}\\PY{o}{.}\\PY{n}{Device}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{cuda:0}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{dtype}\\PY{o}{=}\\PY{n}{dtypes}\\PY{o}{.}\\PY{n}{float32}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t4: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{n}{t5} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{full}\\PY{p}{(}\\PY{p}{(}\\PY{l+m+mi}{64}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{device}\\PY{o}{=}\\PY{n}{devices}\\PY{o}{.}\\PY{n}{Device}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{cuda:0}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{dtype}\\PY{o}{=}\\PY{n}{dtypes}\\PY{o}{.}\\PY{n}{float32}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t5: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{n}{t6} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{full}\\PY{p}{(}\\PY{p}{(}\\PY{l+m+mi}{64}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{device}\\PY{o}{=}\\PY{n}{devices}\\PY{o}{.}\\PY{n}{Device}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{cuda:0}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{dtype}\\PY{o}{=}\\PY{n}{dtypes}\\PY{o}{.}\\PY{n}{float32}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t6: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{n}{t7} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{full}\\PY{p}{(}\\PY{p}{(}\\PY{l+m+mi}{64}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{device}\\PY{o}{=}\\PY{n}{devices}\\PY{o}{.}\\PY{n}{Device}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{cuda:0}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{dtype}\\PY{o}{=}\\PY{n}{dtypes}\\PY{o}{.}\\PY{n}{float32}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t7: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{n}{t8} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{full}\\PY{p}{(}\\PY{p}{(}\\PY{l+m+mi}{64}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{device}\\PY{o}{=}\\PY{n}{devices}\\PY{o}{.}\\PY{n}{Device}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{cuda:0}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{dtype}\\PY{o}{=}\\PY{n}{dtypes}\\PY{o}{.}\\PY{n}{float32}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t8: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{n}{t9} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{full}\\PY{p}{(}\\PY{p}{(}\\PY{l+m+mi}{64}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{device}\\PY{o}{=}\\PY{n}{devices}\\PY{o}{.}\\PY{n}{Device}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{cuda:0}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{dtype}\\PY{o}{=}\\PY{n}{dtypes}\\PY{o}{.}\\PY{n}{float32}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t9: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{n}{t10} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{full}\\PY{p}{(}\\PY{p}{(}\\PY{l+m+mi}{64}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{device}\\PY{o}{=}\\PY{n}{devices}\\PY{o}{.}\\PY{n}{Device}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{cuda:0}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{dtype}\\PY{o}{=}\\PY{n}{dtypes}\\PY{o}{.}\\PY{n}{float32}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t10: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{n}{p11} \\PY{o}{=} \\PY{n}{thunder}\\PY{o}{.}\\PY{n}{distributed}\\PY{o}{.}\\PY{n}{prims}\\PY{o}{.}\\PY{n}{all\\PYZus{}gather}\\PY{p}{(}\\PY{n}{t1}\\PY{p}{,} \\PY{n}{\\PYZus{}torch\\PYZus{}distributed\\PYZus{}distributed\\PYZus{}c10d\\PYZus{}ProcessGroup\\PYZus{}0}\\PY{p}{,} \\PY{k+kc}{True}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} p11: \\PYZdq{}FUTURE cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{n}{t12} \\PY{o}{=} \\PY{n}{thunder}\\PY{o}{.}\\PY{n}{distributed}\\PY{o}{.}\\PY{n}{prims}\\PY{o}{.}\\PY{n}{wait}\\PY{p}{(}\\PY{n}{p11}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t12: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{n}{p13} \\PY{o}{=} \\PY{n}{thunder}\\PY{o}{.}\\PY{n}{distributed}\\PY{o}{.}\\PY{n}{prims}\\PY{o}{.}\\PY{n}{all\\PYZus{}gather}\\PY{p}{(}\\PY{n}{t2}\\PY{p}{,} \\PY{n}{\\PYZus{}torch\\PYZus{}distributed\\PYZus{}distributed\\PYZus{}c10d\\PYZus{}ProcessGroup\\PYZus{}0}\\PY{p}{,} \\PY{k+kc}{True}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} p13: \\PYZdq{}FUTURE cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{n}{t14} \\PY{o}{=} \\PY{n}{thunder}\\PY{o}{.}\\PY{n}{distributed}\\PY{o}{.}\\PY{n}{prims}\\PY{o}{.}\\PY{n}{wait}\\PY{p}{(}\\PY{n}{p13}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t14: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{n}{t15} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{linear}\\PY{p}{(}\\PY{n}{t0}\\PY{p}{,} \\PY{n}{t12}\\PY{p}{,} \\PY{k+kc}{None}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t15: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{n}{t16} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{tanh}\\PY{p}{(}\\PY{n}{t15}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t16: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{n}{t17} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{linear}\\PY{p}{(}\\PY{n}{t16}\\PY{p}{,} \\PY{n}{t14}\\PY{p}{,} \\PY{k+kc}{None}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t17: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{n}{t18} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{add}\\PY{p}{(}\\PY{n}{t6}\\PY{p}{,} \\PY{n}{t7}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t18: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{n}{t19} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{add}\\PY{p}{(}\\PY{n}{t3}\\PY{p}{,} \\PY{n}{t8}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t19: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{n}{t20} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{add}\\PY{p}{(}\\PY{n}{t5}\\PY{p}{,} \\PY{n}{t9}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t20: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{n}{t21} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{t18}\\PY{p}{,} \\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t21: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t21 = prims.reshape(t18, (64, 64)) \\PYZsh{} t21: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{n}{t22} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{matmul}\\PY{p}{(}\\PY{n}{t21}\\PY{p}{,} \\PY{n}{t14}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t22: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t22 = prims.matmul(t21, t14) \\PYZsh{} t22: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{n}{t23} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{t18}\\PY{p}{,} \\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t23: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t23 = prims.reshape(t18, (64, 64)) \\PYZsh{} t23: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{n}{t24} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{transpose}\\PY{p}{(}\\PY{n}{t23}\\PY{p}{,} \\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t24: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{n}{t25} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{t16}\\PY{p}{,} \\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t25: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t25 = prims.reshape(t16, (64, 64)) \\PYZsh{} t25: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{n}{t26} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{matmul}\\PY{p}{(}\\PY{n}{t24}\\PY{p}{,} \\PY{n}{t25}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t26: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t26 = prims.matmul(t24, t25) \\PYZsh{} t26: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{n}{t27} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{add}\\PY{p}{(}\\PY{n}{t10}\\PY{p}{,} \\PY{n}{t22}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t27: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{n}{t28} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{add}\\PY{p}{(}\\PY{n}{t20}\\PY{p}{,} \\PY{n}{t26}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t28: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{n}{t29} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{mul}\\PY{p}{(}\\PY{n}{t16}\\PY{p}{,} \\PY{n}{t16}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t29: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t29 = prims.mul(t16, t16) \\PYZsh{} t29: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{n}{t30} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{sub}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{t29}\\PY{p}{,} \\PY{n}{alpha}\\PY{o}{=}\\PY{k+kc}{None}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t30: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} \\PYZus{} = prims.convert\\PYZus{}element\\PYZus{}type(1, float)}\n",
- " \\PY{c+c1}{\\PYZsh{} t30 = prims.sub(1.0, t29) \\PYZsh{} t30: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{n}{t31} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{mul}\\PY{p}{(}\\PY{n}{t27}\\PY{p}{,} \\PY{n}{t30}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t31: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t31 = prims.mul(t27, t30) \\PYZsh{} t31: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{n}{t32} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{t31}\\PY{p}{,} \\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t32: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t32 = prims.reshape(t31, (64, 64)) \\PYZsh{} t32: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{n}{t33} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{matmul}\\PY{p}{(}\\PY{n}{t32}\\PY{p}{,} \\PY{n}{t12}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t33: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t33 = prims.matmul(t32, t12) \\PYZsh{} t33: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{n}{t34} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{t31}\\PY{p}{,} \\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t34: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t34 = prims.reshape(t31, (64, 64)) \\PYZsh{} t34: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{n}{t35} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{transpose}\\PY{p}{(}\\PY{n}{t34}\\PY{p}{,} \\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t35: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{n}{t36} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{t0}\\PY{p}{,} \\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t36: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t36 = prims.reshape(t0, (64, 64)) \\PYZsh{} t36: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{n}{t37} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{matmul}\\PY{p}{(}\\PY{n}{t35}\\PY{p}{,} \\PY{n}{t36}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t37: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t37 = prims.matmul(t35, t36) \\PYZsh{} t37: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{n}{t38} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{add}\\PY{p}{(}\\PY{n}{t19}\\PY{p}{,} \\PY{n}{t33}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t38: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{n}{t39} \\PY{o}{=} \\PY{n}{prims}\\PY{o}{.}\\PY{n}{add}\\PY{p}{(}\\PY{n}{t4}\\PY{p}{,} \\PY{n}{t37}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t39: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{n}{t40} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{true\\PYZus{}divide}\\PY{p}{(}\\PY{n}{t28}\\PY{p}{,} \\PY{l+m+mi}{2}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t40: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} \\PYZus{} = prims.convert\\PYZus{}element\\PYZus{}type(2, float)}\n",
- " \\PY{c+c1}{\\PYZsh{} t40 = prims.div(t28, 2.0) \\PYZsh{} t40: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{n}{p41} \\PY{o}{=} \\PY{n}{thunder}\\PY{o}{.}\\PY{n}{distributed}\\PY{o}{.}\\PY{n}{prims}\\PY{o}{.}\\PY{n}{reduce\\PYZus{}scatter}\\PY{p}{(}\\PY{n}{t40}\\PY{p}{,} \\PY{n}{\\PYZus{}DistributedReduceOps\\PYZus{}1}\\PY{p}{,} \\PY{n}{\\PYZus{}torch\\PYZus{}distributed\\PYZus{}distributed\\PYZus{}c10d\\PYZus{}ProcessGroup\\PYZus{}0}\\PY{p}{,} \\PY{k+kc}{True}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} p41: \\PYZdq{}FUTURE cuda:0 f32[32, 64]\\PYZdq{}}\n",
- " \\PY{n}{t42} \\PY{o}{=} \\PY{n}{thunder}\\PY{o}{.}\\PY{n}{distributed}\\PY{o}{.}\\PY{n}{prims}\\PY{o}{.}\\PY{n}{wait}\\PY{p}{(}\\PY{n}{p41}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t42: \\PYZdq{}cuda:0 f32[32, 64]\\PYZdq{}}\n",
- " \\PY{n}{t43} \\PY{o}{=} \\PY{n}{ltorch}\\PY{o}{.}\\PY{n}{true\\PYZus{}divide}\\PY{p}{(}\\PY{n}{t39}\\PY{p}{,} \\PY{l+m+mi}{2}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t43: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} \\PYZus{} = prims.convert\\PYZus{}element\\PYZus{}type(2, float)}\n",
- " \\PY{c+c1}{\\PYZsh{} t43 = prims.div(t39, 2.0) \\PYZsh{} t43: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{n}{p44} \\PY{o}{=} \\PY{n}{thunder}\\PY{o}{.}\\PY{n}{distributed}\\PY{o}{.}\\PY{n}{prims}\\PY{o}{.}\\PY{n}{reduce\\PYZus{}scatter}\\PY{p}{(}\\PY{n}{t43}\\PY{p}{,} \\PY{n}{\\PYZus{}DistributedReduceOps\\PYZus{}1}\\PY{p}{,} \\PY{n}{\\PYZus{}torch\\PYZus{}distributed\\PYZus{}distributed\\PYZus{}c10d\\PYZus{}ProcessGroup\\PYZus{}0}\\PY{p}{,} \\PY{k+kc}{True}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} p44: \\PYZdq{}FUTURE cuda:0 f32[32, 64]\\PYZdq{}}\n",
- " \\PY{n}{t45} \\PY{o}{=} \\PY{n}{thunder}\\PY{o}{.}\\PY{n}{distributed}\\PY{o}{.}\\PY{n}{prims}\\PY{o}{.}\\PY{n}{wait}\\PY{p}{(}\\PY{n}{p44}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t45: \\PYZdq{}cuda:0 f32[32, 64]\\PYZdq{}}\n",
- " \\PY{k}{return} \\PY{p}{(}\\PY{p}{(}\\PY{p}{\\PYZob{}}\\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{output}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{n}{t17}\\PY{p}{,} \\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{flat\\PYZus{}args}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{p}{[}\\PY{n}{t0}\\PY{p}{,} \\PY{n}{t12}\\PY{p}{,} \\PY{n}{t14}\\PY{p}{]}\\PY{p}{,} \\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{flat\\PYZus{}output}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{p}{(}\\PY{n}{t17}\\PY{p}{,}\\PY{p}{)}\\PY{p}{\\PYZcb{}}\\PY{p}{,} \\PY{p}{(}\\PY{p}{(}\\PY{n}{t0}\\PY{p}{,} \\PY{n}{t14}\\PY{p}{,} \\PY{n}{t16}\\PY{p}{)}\\PY{p}{,} \\PY{p}{(}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{,} \\PY{p}{(}\\PY{n}{t38}\\PY{p}{,} \\PY{n}{t45}\\PY{p}{,} \\PY{n}{t42}\\PY{p}{)}\\PY{p}{)}\n",
- "\\end{Verbatim}\n"
- ],
- "text/plain": [
- "# Constructed by Dead Code Elimination (took 1 milliseconds)\n",
- "import thunder\n",
- "import thunder.core.devices as devices\n",
- "import thunder.core.dtypes as dtypes\n",
- "import thunder.core.prims as prims\n",
- "import thunder.distributed.prims\n",
- "import thunder.torch as ltorch\n",
- "import torch\n",
- "from thunder.executors.torchex import no_autocast\n",
- "\n",
- "@torch.no_grad()\n",
- "@no_autocast()\n",
- "def _value_and_grad(*args):\n",
- " # args: \"Collection\" \n",
- " t0, \\\n",
- " t1, \\\n",
- " t2, \\\n",
- " = args\n",
- " t3 = prims.full((64, 64), 1, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t3: \"cuda:0 f32[64, 64]\"\n",
- " t4 = prims.full((64, 64), 1, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t4: \"cuda:0 f32[64, 64]\"\n",
- " t5 = prims.full((64, 64), 1, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t5: \"cuda:0 f32[64, 64]\"\n",
- " t6 = prims.full((64, 64), 1, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t6: \"cuda:0 f32[64, 64]\"\n",
- " t7 = prims.full((64, 64), 1, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t7: \"cuda:0 f32[64, 64]\"\n",
- " t8 = prims.full((64, 64), 1, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t8: \"cuda:0 f32[64, 64]\"\n",
- " t9 = prims.full((64, 64), 1, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t9: \"cuda:0 f32[64, 64]\"\n",
- " t10 = prims.full((64, 64), 1, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t10: \"cuda:0 f32[64, 64]\"\n",
- " p11 = thunder.distributed.prims.all_gather(t1, _torch_distributed_distributed_c10d_ProcessGroup_0, True) # p11: \"FUTURE cuda:0 f32[64, 64]\"\n",
- " t12 = thunder.distributed.prims.wait(p11) # t12: \"cuda:0 f32[64, 64]\"\n",
- " p13 = thunder.distributed.prims.all_gather(t2, _torch_distributed_distributed_c10d_ProcessGroup_0, True) # p13: \"FUTURE cuda:0 f32[64, 64]\"\n",
- " t14 = thunder.distributed.prims.wait(p13) # t14: \"cuda:0 f32[64, 64]\"\n",
- " t15 = prims.linear(t0, t12, None) # t15: \"cuda:0 f32[64, 64]\"\n",
- " t16 = prims.tanh(t15) # t16: \"cuda:0 f32[64, 64]\"\n",
- " t17 = prims.linear(t16, t14, None) # t17: \"cuda:0 f32[64, 64]\"\n",
- " t18 = prims.add(t6, t7) # t18: \"cuda:0 f32[64, 64]\"\n",
- " t19 = prims.add(t3, t8) # t19: \"cuda:0 f32[64, 64]\"\n",
- " t20 = prims.add(t5, t9) # t20: \"cuda:0 f32[64, 64]\"\n",
- " t21 = ltorch.reshape(t18, -1, 64) # t21: \"cuda:0 f32[64, 64]\"\n",
- " # t21 = prims.reshape(t18, (64, 64)) # t21: \"cuda:0 f32[64, 64]\"\n",
- " t22 = ltorch.matmul(t21, t14) # t22: \"cuda:0 f32[64, 64]\"\n",
- " # t22 = prims.matmul(t21, t14) # t22: \"cuda:0 f32[64, 64]\"\n",
- " t23 = ltorch.reshape(t18, -1, 64) # t23: \"cuda:0 f32[64, 64]\"\n",
- " # t23 = prims.reshape(t18, (64, 64)) # t23: \"cuda:0 f32[64, 64]\"\n",
- " t24 = prims.transpose(t23, (1, 0)) # t24: \"cuda:0 f32[64, 64]\"\n",
- " t25 = ltorch.reshape(t16, -1, 64) # t25: \"cuda:0 f32[64, 64]\"\n",
- " # t25 = prims.reshape(t16, (64, 64)) # t25: \"cuda:0 f32[64, 64]\"\n",
- " t26 = ltorch.matmul(t24, t25) # t26: \"cuda:0 f32[64, 64]\"\n",
- " # t26 = prims.matmul(t24, t25) # t26: \"cuda:0 f32[64, 64]\"\n",
- " t27 = prims.add(t10, t22) # t27: \"cuda:0 f32[64, 64]\"\n",
- " t28 = prims.add(t20, t26) # t28: \"cuda:0 f32[64, 64]\"\n",
- " t29 = ltorch.mul(t16, t16) # t29: \"cuda:0 f32[64, 64]\"\n",
- " # t29 = prims.mul(t16, t16) # t29: \"cuda:0 f32[64, 64]\"\n",
- " t30 = ltorch.sub(1, t29, alpha=None) # t30: \"cuda:0 f32[64, 64]\"\n",
- " # _ = prims.convert_element_type(1, float)\n",
- " # t30 = prims.sub(1.0, t29) # t30: \"cuda:0 f32[64, 64]\"\n",
- " t31 = ltorch.mul(t27, t30) # t31: \"cuda:0 f32[64, 64]\"\n",
- " # t31 = prims.mul(t27, t30) # t31: \"cuda:0 f32[64, 64]\"\n",
- " t32 = ltorch.reshape(t31, -1, 64) # t32: \"cuda:0 f32[64, 64]\"\n",
- " # t32 = prims.reshape(t31, (64, 64)) # t32: \"cuda:0 f32[64, 64]\"\n",
- " t33 = ltorch.matmul(t32, t12) # t33: \"cuda:0 f32[64, 64]\"\n",
- " # t33 = prims.matmul(t32, t12) # t33: \"cuda:0 f32[64, 64]\"\n",
- " t34 = ltorch.reshape(t31, -1, 64) # t34: \"cuda:0 f32[64, 64]\"\n",
- " # t34 = prims.reshape(t31, (64, 64)) # t34: \"cuda:0 f32[64, 64]\"\n",
- " t35 = prims.transpose(t34, (1, 0)) # t35: \"cuda:0 f32[64, 64]\"\n",
- " t36 = ltorch.reshape(t0, -1, 64) # t36: \"cuda:0 f32[64, 64]\"\n",
- " # t36 = prims.reshape(t0, (64, 64)) # t36: \"cuda:0 f32[64, 64]\"\n",
- " t37 = ltorch.matmul(t35, t36) # t37: \"cuda:0 f32[64, 64]\"\n",
- " # t37 = prims.matmul(t35, t36) # t37: \"cuda:0 f32[64, 64]\"\n",
- " t38 = prims.add(t19, t33) # t38: \"cuda:0 f32[64, 64]\"\n",
- " t39 = prims.add(t4, t37) # t39: \"cuda:0 f32[64, 64]\"\n",
- " t40 = ltorch.true_divide(t28, 2) # t40: \"cuda:0 f32[64, 64]\"\n",
- " # _ = prims.convert_element_type(2, float)\n",
- " # t40 = prims.div(t28, 2.0) # t40: \"cuda:0 f32[64, 64]\"\n",
- " p41 = thunder.distributed.prims.reduce_scatter(t40, _DistributedReduceOps_1, _torch_distributed_distributed_c10d_ProcessGroup_0, True) # p41: \"FUTURE cuda:0 f32[32, 64]\"\n",
- " t42 = thunder.distributed.prims.wait(p41) # t42: \"cuda:0 f32[32, 64]\"\n",
- " t43 = ltorch.true_divide(t39, 2) # t43: \"cuda:0 f32[64, 64]\"\n",
- " # _ = prims.convert_element_type(2, float)\n",
- " # t43 = prims.div(t39, 2.0) # t43: \"cuda:0 f32[64, 64]\"\n",
- " p44 = thunder.distributed.prims.reduce_scatter(t43, _DistributedReduceOps_1, _torch_distributed_distributed_c10d_ProcessGroup_0, True) # p44: \"FUTURE cuda:0 f32[32, 64]\"\n",
- " t45 = thunder.distributed.prims.wait(p44) # t45: \"cuda:0 f32[32, 64]\"\n",
- " return (({'output': t17, 'flat_args': [t0, t12, t14], 'flat_output': (t17,)}, ((t0, t14, t16), ())), (t38, t45, t42))"
- ]
- },
- "execution_count": 13,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"from thunder.core.transforms import value_and_grad\n",
"\n",
@@ -1151,576 +302,9 @@
},
{
"cell_type": "code",
- "execution_count": 14,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- "# Constructed by Delete Last Used (took 0 milliseconds)\n",
- "import torch\n",
- "import torch.nn.functional\n",
- "from thunder.executors.torchex import no_autocast\n",
- "\n",
- "@torch.no_grad()\n",
- "@no_autocast()\n",
- "def _value_and_grad(*args):\n",
- " # args: "Collection" \n",
- " t0, \\\n",
- " t1, \\\n",
- " t2, \\\n",
- " = args\n",
- " del args\n",
- " t3 = torch.full((64, 64), 1, device=torch.device("cuda:0"), dtype=torch.float32) # t3: "cuda:0 f32[64, 64]"\n",
- " # t3 = ltorch.full((64, 64), 1, device=torch.device("cuda:0"), dtype=torch.float32) # t3: "cuda:0 f32[64, 64]"\n",
- " # t3 = prims.full((64, 64), 1, device=devices.Device("cuda:0"), dtype=dtypes.float32) # t3: "cuda:0 f32[64, 64]"\n",
- " t4 = torch.full((64, 64), 1, device=torch.device("cuda:0"), dtype=torch.float32) # t4: "cuda:0 f32[64, 64]"\n",
- " # t4 = ltorch.full((64, 64), 1, device=torch.device("cuda:0"), dtype=torch.float32) # t4: "cuda:0 f32[64, 64]"\n",
- " # t4 = prims.full((64, 64), 1, device=devices.Device("cuda:0"), dtype=dtypes.float32) # t4: "cuda:0 f32[64, 64]"\n",
- " t5 = torch.full((64, 64), 1, device=torch.device("cuda:0"), dtype=torch.float32) # t5: "cuda:0 f32[64, 64]"\n",
- " # t5 = ltorch.full((64, 64), 1, device=torch.device("cuda:0"), dtype=torch.float32) # t5: "cuda:0 f32[64, 64]"\n",
- " # t5 = prims.full((64, 64), 1, device=devices.Device("cuda:0"), dtype=dtypes.float32) # t5: "cuda:0 f32[64, 64]"\n",
- " t6 = torch.full((64, 64), 1, device=torch.device("cuda:0"), dtype=torch.float32) # t6: "cuda:0 f32[64, 64]"\n",
- " # t6 = ltorch.full((64, 64), 1, device=torch.device("cuda:0"), dtype=torch.float32) # t6: "cuda:0 f32[64, 64]"\n",
- " # t6 = prims.full((64, 64), 1, device=devices.Device("cuda:0"), dtype=dtypes.float32) # t6: "cuda:0 f32[64, 64]"\n",
- " t7 = torch.full((64, 64), 1, device=torch.device("cuda:0"), dtype=torch.float32) # t7: "cuda:0 f32[64, 64]"\n",
- " # t7 = ltorch.full((64, 64), 1, device=torch.device("cuda:0"), dtype=torch.float32) # t7: "cuda:0 f32[64, 64]"\n",
- " # t7 = prims.full((64, 64), 1, device=devices.Device("cuda:0"), dtype=dtypes.float32) # t7: "cuda:0 f32[64, 64]"\n",
- " t8 = torch.full((64, 64), 1, device=torch.device("cuda:0"), dtype=torch.float32) # t8: "cuda:0 f32[64, 64]"\n",
- " # t8 = ltorch.full((64, 64), 1, device=torch.device("cuda:0"), dtype=torch.float32) # t8: "cuda:0 f32[64, 64]"\n",
- " # t8 = prims.full((64, 64), 1, device=devices.Device("cuda:0"), dtype=dtypes.float32) # t8: "cuda:0 f32[64, 64]"\n",
- " t9 = torch.full((64, 64), 1, device=torch.device("cuda:0"), dtype=torch.float32) # t9: "cuda:0 f32[64, 64]"\n",
- " # t9 = ltorch.full((64, 64), 1, device=torch.device("cuda:0"), dtype=torch.float32) # t9: "cuda:0 f32[64, 64]"\n",
- " # t9 = prims.full((64, 64), 1, device=devices.Device("cuda:0"), dtype=dtypes.float32) # t9: "cuda:0 f32[64, 64]"\n",
- " t10 = torch.full((64, 64), 1, device=torch.device("cuda:0"), dtype=torch.float32) # t10: "cuda:0 f32[64, 64]"\n",
- " # t10 = ltorch.full((64, 64), 1, device=torch.device("cuda:0"), dtype=torch.float32) # t10: "cuda:0 f32[64, 64]"\n",
- " # t10 = prims.full((64, 64), 1, device=devices.Device("cuda:0"), dtype=dtypes.float32) # t10: "cuda:0 f32[64, 64]"\n",
- " p11 = torch_all_gather_prim_impl(t1, _torch_distributed_distributed_c10d_ProcessGroup_2, True) # p11: "FUTURE cuda:0 f32[64, 64]"\n",
- " del t1\n",
- " t12 = torch_wait_prim_impl(p11) # t12: "cuda:0 f32[64, 64]"\n",
- " del p11\n",
- " p13 = torch_all_gather_prim_impl(t2, _torch_distributed_distributed_c10d_ProcessGroup_2, True) # p13: "FUTURE cuda:0 f32[64, 64]"\n",
- " del t2\n",
- " t14 = torch_wait_prim_impl(p13) # t14: "cuda:0 f32[64, 64]"\n",
- " del p13\n",
- " t15 = torch.nn.functional.linear(t0, t12, None) # t15: "cuda:0 f32[64, 64]"\n",
- " # t15 = ltorch.linear(t0, t12, None) # t15: "cuda:0 f32[64, 64]"\n",
- " # t15 = prims.linear(t0, t12, None) # t15: "cuda:0 f32[64, 64]"\n",
- " t16 = torch.tanh(t15) # t16: "cuda:0 f32[64, 64]"\n",
- " # t16 = ltorch.tanh(t15) # t16: "cuda:0 f32[64, 64]"\n",
- " # t16 = prims.tanh(t15) # t16: "cuda:0 f32[64, 64]"\n",
- " del t15\n",
- " t17 = torch.nn.functional.linear(t16, t14, None) # t17: "cuda:0 f32[64, 64]"\n",
- " # t17 = ltorch.linear(t16, t14, None) # t17: "cuda:0 f32[64, 64]"\n",
- " # t17 = prims.linear(t16, t14, None) # t17: "cuda:0 f32[64, 64]"\n",
- " t18 = torch.add(t6, t7) # t18: "cuda:0 f32[64, 64]"\n",
- " # t18 = ltorch.add(t6, t7, alpha=None) # t18: "cuda:0 f32[64, 64]"\n",
- " # t18 = prims.add(t6, t7) # t18: "cuda:0 f32[64, 64]"\n",
- " del t6, t7\n",
- " t19 = torch.add(t3, t8) # t19: "cuda:0 f32[64, 64]"\n",
- " # t19 = ltorch.add(t3, t8, alpha=None) # t19: "cuda:0 f32[64, 64]"\n",
- " # t19 = prims.add(t3, t8) # t19: "cuda:0 f32[64, 64]"\n",
- " del t3, t8\n",
- " t20 = torch.add(t5, t9) # t20: "cuda:0 f32[64, 64]"\n",
- " # t20 = ltorch.add(t5, t9, alpha=None) # t20: "cuda:0 f32[64, 64]"\n",
- " # t20 = prims.add(t5, t9) # t20: "cuda:0 f32[64, 64]"\n",
- " del t5, t9\n",
- " t21 = torch.reshape(t18, (-1, 64)) # t21: "cuda:0 f32[64, 64]"\n",
- " # t21 = ltorch.reshape(t18, (-1, 64)) # t21: "cuda:0 f32[64, 64]"\n",
- " # t21 = prims.reshape(t18, (64, 64)) # t21: "cuda:0 f32[64, 64]"\n",
- " t22 = torch.matmul(t21, t14) # t22: "cuda:0 f32[64, 64]"\n",
- " # t22 = ltorch.matmul(t21, t14) # t22: "cuda:0 f32[64, 64]"\n",
- " # t22 = prims.matmul(t21, t14) # t22: "cuda:0 f32[64, 64]"\n",
- " del t21\n",
- " t23 = torch.reshape(t18, (-1, 64)) # t23: "cuda:0 f32[64, 64]"\n",
- " # t23 = ltorch.reshape(t18, (-1, 64)) # t23: "cuda:0 f32[64, 64]"\n",
- " # t23 = prims.reshape(t18, (64, 64)) # t23: "cuda:0 f32[64, 64]"\n",
- " del t18\n",
- " t24 = torch.permute(t23, (1, 0)) # t24: "cuda:0 f32[64, 64]"\n",
- " # t24 = ltorch.permute(t23, (1, 0)) # t24: "cuda:0 f32[64, 64]"\n",
- " # t24 = prims.transpose(t23, (1, 0)) # t24: "cuda:0 f32[64, 64]"\n",
- " del t23\n",
- " t25 = torch.reshape(t16, (-1, 64)) # t25: "cuda:0 f32[64, 64]"\n",
- " # t25 = ltorch.reshape(t16, (-1, 64)) # t25: "cuda:0 f32[64, 64]"\n",
- " # t25 = prims.reshape(t16, (64, 64)) # t25: "cuda:0 f32[64, 64]"\n",
- " t26 = torch.matmul(t24, t25) # t26: "cuda:0 f32[64, 64]"\n",
- " # t26 = ltorch.matmul(t24, t25) # t26: "cuda:0 f32[64, 64]"\n",
- " # t26 = prims.matmul(t24, t25) # t26: "cuda:0 f32[64, 64]"\n",
- " del t24, t25\n",
- " t27 = torch.add(t10, t22) # t27: "cuda:0 f32[64, 64]"\n",
- " # t27 = ltorch.add(t10, t22, alpha=None) # t27: "cuda:0 f32[64, 64]"\n",
- " # t27 = prims.add(t10, t22) # t27: "cuda:0 f32[64, 64]"\n",
- " del t10, t22\n",
- " t28 = torch.add(t20, t26) # t28: "cuda:0 f32[64, 64]"\n",
- " # t28 = ltorch.add(t20, t26, alpha=None) # t28: "cuda:0 f32[64, 64]"\n",
- " # t28 = prims.add(t20, t26) # t28: "cuda:0 f32[64, 64]"\n",
- " del t20, t26\n",
- " t29 = torch.mul(t16, t16) # t29: "cuda:0 f32[64, 64]"\n",
- " # t29 = ltorch.mul(t16, t16) # t29: "cuda:0 f32[64, 64]"\n",
- " # t29 = prims.mul(t16, t16) # t29: "cuda:0 f32[64, 64]"\n",
- " t30 = torch.sub(1, t29) # t30: "cuda:0 f32[64, 64]"\n",
- " # t30 = ltorch.sub(1, t29, alpha=None) # t30: "cuda:0 f32[64, 64]"\n",
- " # _ = prims.convert_element_type(1, float)\n",
- " # t30 = prims.sub(1.0, t29) # t30: "cuda:0 f32[64, 64]"\n",
- " del t29\n",
- " t31 = torch.mul(t27, t30) # t31: "cuda:0 f32[64, 64]"\n",
- " # t31 = ltorch.mul(t27, t30) # t31: "cuda:0 f32[64, 64]"\n",
- " # t31 = prims.mul(t27, t30) # t31: "cuda:0 f32[64, 64]"\n",
- " del t27, t30\n",
- " t32 = torch.reshape(t31, (-1, 64)) # t32: "cuda:0 f32[64, 64]"\n",
- " # t32 = ltorch.reshape(t31, (-1, 64)) # t32: "cuda:0 f32[64, 64]"\n",
- " # t32 = prims.reshape(t31, (64, 64)) # t32: "cuda:0 f32[64, 64]"\n",
- " t33 = torch.matmul(t32, t12) # t33: "cuda:0 f32[64, 64]"\n",
- " # t33 = ltorch.matmul(t32, t12) # t33: "cuda:0 f32[64, 64]"\n",
- " # t33 = prims.matmul(t32, t12) # t33: "cuda:0 f32[64, 64]"\n",
- " del t32\n",
- " t34 = torch.reshape(t31, (-1, 64)) # t34: "cuda:0 f32[64, 64]"\n",
- " # t34 = ltorch.reshape(t31, (-1, 64)) # t34: "cuda:0 f32[64, 64]"\n",
- " # t34 = prims.reshape(t31, (64, 64)) # t34: "cuda:0 f32[64, 64]"\n",
- " del t31\n",
- " t35 = torch.permute(t34, (1, 0)) # t35: "cuda:0 f32[64, 64]"\n",
- " # t35 = ltorch.permute(t34, (1, 0)) # t35: "cuda:0 f32[64, 64]"\n",
- " # t35 = prims.transpose(t34, (1, 0)) # t35: "cuda:0 f32[64, 64]"\n",
- " del t34\n",
- " t36 = torch.reshape(t0, (-1, 64)) # t36: "cuda:0 f32[64, 64]"\n",
- " # t36 = ltorch.reshape(t0, (-1, 64)) # t36: "cuda:0 f32[64, 64]"\n",
- " # t36 = prims.reshape(t0, (64, 64)) # t36: "cuda:0 f32[64, 64]"\n",
- " t37 = torch.matmul(t35, t36) # t37: "cuda:0 f32[64, 64]"\n",
- " # t37 = ltorch.matmul(t35, t36) # t37: "cuda:0 f32[64, 64]"\n",
- " # t37 = prims.matmul(t35, t36) # t37: "cuda:0 f32[64, 64]"\n",
- " del t35, t36\n",
- " t38 = torch.add(t19, t33) # t38: "cuda:0 f32[64, 64]"\n",
- " # t38 = ltorch.add(t19, t33, alpha=None) # t38: "cuda:0 f32[64, 64]"\n",
- " # t38 = prims.add(t19, t33) # t38: "cuda:0 f32[64, 64]"\n",
- " del t19, t33\n",
- " t39 = torch.add(t4, t37) # t39: "cuda:0 f32[64, 64]"\n",
- " # t39 = ltorch.add(t4, t37, alpha=None) # t39: "cuda:0 f32[64, 64]"\n",
- " # t39 = prims.add(t4, t37) # t39: "cuda:0 f32[64, 64]"\n",
- " del t4, t37\n",
- " t40 = torch.true_divide(t28, 2) # t40: "cuda:0 f32[64, 64]"\n",
- " # t40 = ltorch.true_divide(t28, 2) # t40: "cuda:0 f32[64, 64]"\n",
- " # _ = prims.convert_element_type(2, float)\n",
- " # t40 = prims.div(t28, 2.0) # t40: "cuda:0 f32[64, 64]"\n",
- " del t28\n",
- " p41 = torch_reduce_scatter_prim_impl(t40, _DistributedReduceOps_3, _torch_distributed_distributed_c10d_ProcessGroup_2, True) # p41: "FUTURE cuda:0 f32[32, 64]"\n",
- " del t40\n",
- " t42 = torch_wait_prim_impl(p41) # t42: "cuda:0 f32[32, 64]"\n",
- " del p41\n",
- " t43 = torch.true_divide(t39, 2) # t43: "cuda:0 f32[64, 64]"\n",
- " # t43 = ltorch.true_divide(t39, 2) # t43: "cuda:0 f32[64, 64]"\n",
- " # _ = prims.convert_element_type(2, float)\n",
- " # t43 = prims.div(t39, 2.0) # t43: "cuda:0 f32[64, 64]"\n",
- " del t39\n",
- " p44 = torch_reduce_scatter_prim_impl(t43, _DistributedReduceOps_3, _torch_distributed_distributed_c10d_ProcessGroup_2, True) # p44: "FUTURE cuda:0 f32[32, 64]"\n",
- " del t43\n",
- " t45 = torch_wait_prim_impl(p44) # t45: "cuda:0 f32[32, 64]"\n",
- " del p44\n",
- " return (({'output': t17, 'flat_args': [t0, t12, t14], 'flat_output': (t17,)}, ((t0, t14, t16), ())), (t38, t45, t42))\n",
- "
\n"
- ],
- "text/latex": [
- "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n",
- "\\PY{c+c1}{\\PYZsh{} Constructed by Delete Last Used (took 0 milliseconds)}\n",
- "\\PY{k+kn}{import} \\PY{n+nn}{torch}\n",
- "\\PY{k+kn}{import} \\PY{n+nn}{torch}\\PY{n+nn}{.}\\PY{n+nn}{nn}\\PY{n+nn}{.}\\PY{n+nn}{functional}\n",
- "\\PY{k+kn}{from} \\PY{n+nn}{thunder}\\PY{n+nn}{.}\\PY{n+nn}{executors}\\PY{n+nn}{.}\\PY{n+nn}{torchex} \\PY{k+kn}{import} \\PY{n}{no\\PYZus{}autocast}\n",
- "\n",
- "\\PY{n+nd}{@torch}\\PY{o}{.}\\PY{n}{no\\PYZus{}grad}\\PY{p}{(}\\PY{p}{)}\n",
- "\\PY{n+nd}{@no\\PYZus{}autocast}\\PY{p}{(}\\PY{p}{)}\n",
- "\\PY{k}{def} \\PY{n+nf}{\\PYZus{}value\\PYZus{}and\\PYZus{}grad}\\PY{p}{(}\\PY{o}{*}\\PY{n}{args}\\PY{p}{)}\\PY{p}{:}\n",
- " \\PY{c+c1}{\\PYZsh{} args: \\PYZdq{}Collection\\PYZdq{} }\n",
- " \\PY{n}{t0}\\PY{p}{,} \\PYZbs{}\n",
- " \\PY{n}{t1}\\PY{p}{,} \\PYZbs{}\n",
- " \\PY{n}{t2}\\PY{p}{,} \\PYZbs{}\n",
- " \\PY{o}{=} \\PY{n}{args}\n",
- " \\PY{k}{del} \\PY{n}{args}\n",
- " \\PY{n}{t3} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{full}\\PY{p}{(}\\PY{p}{(}\\PY{l+m+mi}{64}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{device}\\PY{o}{=}\\PY{n}{torch}\\PY{o}{.}\\PY{n}{device}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{cuda:0}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{dtype}\\PY{o}{=}\\PY{n}{torch}\\PY{o}{.}\\PY{n}{float32}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t3: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t3 = ltorch.full((64, 64), 1, device=torch.device(\\PYZdq{}cuda:0\\PYZdq{}), dtype=torch.float32) \\PYZsh{} t3: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t3 = prims.full((64, 64), 1, device=devices.Device(\\PYZdq{}cuda:0\\PYZdq{}), dtype=dtypes.float32) \\PYZsh{} t3: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{n}{t4} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{full}\\PY{p}{(}\\PY{p}{(}\\PY{l+m+mi}{64}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{device}\\PY{o}{=}\\PY{n}{torch}\\PY{o}{.}\\PY{n}{device}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{cuda:0}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{dtype}\\PY{o}{=}\\PY{n}{torch}\\PY{o}{.}\\PY{n}{float32}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t4: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t4 = ltorch.full((64, 64), 1, device=torch.device(\\PYZdq{}cuda:0\\PYZdq{}), dtype=torch.float32) \\PYZsh{} t4: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t4 = prims.full((64, 64), 1, device=devices.Device(\\PYZdq{}cuda:0\\PYZdq{}), dtype=dtypes.float32) \\PYZsh{} t4: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{n}{t5} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{full}\\PY{p}{(}\\PY{p}{(}\\PY{l+m+mi}{64}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{device}\\PY{o}{=}\\PY{n}{torch}\\PY{o}{.}\\PY{n}{device}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{cuda:0}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{dtype}\\PY{o}{=}\\PY{n}{torch}\\PY{o}{.}\\PY{n}{float32}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t5: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t5 = ltorch.full((64, 64), 1, device=torch.device(\\PYZdq{}cuda:0\\PYZdq{}), dtype=torch.float32) \\PYZsh{} t5: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t5 = prims.full((64, 64), 1, device=devices.Device(\\PYZdq{}cuda:0\\PYZdq{}), dtype=dtypes.float32) \\PYZsh{} t5: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{n}{t6} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{full}\\PY{p}{(}\\PY{p}{(}\\PY{l+m+mi}{64}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{device}\\PY{o}{=}\\PY{n}{torch}\\PY{o}{.}\\PY{n}{device}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{cuda:0}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{dtype}\\PY{o}{=}\\PY{n}{torch}\\PY{o}{.}\\PY{n}{float32}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t6: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t6 = ltorch.full((64, 64), 1, device=torch.device(\\PYZdq{}cuda:0\\PYZdq{}), dtype=torch.float32) \\PYZsh{} t6: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t6 = prims.full((64, 64), 1, device=devices.Device(\\PYZdq{}cuda:0\\PYZdq{}), dtype=dtypes.float32) \\PYZsh{} t6: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{n}{t7} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{full}\\PY{p}{(}\\PY{p}{(}\\PY{l+m+mi}{64}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{device}\\PY{o}{=}\\PY{n}{torch}\\PY{o}{.}\\PY{n}{device}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{cuda:0}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{dtype}\\PY{o}{=}\\PY{n}{torch}\\PY{o}{.}\\PY{n}{float32}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t7: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t7 = ltorch.full((64, 64), 1, device=torch.device(\\PYZdq{}cuda:0\\PYZdq{}), dtype=torch.float32) \\PYZsh{} t7: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t7 = prims.full((64, 64), 1, device=devices.Device(\\PYZdq{}cuda:0\\PYZdq{}), dtype=dtypes.float32) \\PYZsh{} t7: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{n}{t8} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{full}\\PY{p}{(}\\PY{p}{(}\\PY{l+m+mi}{64}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{device}\\PY{o}{=}\\PY{n}{torch}\\PY{o}{.}\\PY{n}{device}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{cuda:0}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{dtype}\\PY{o}{=}\\PY{n}{torch}\\PY{o}{.}\\PY{n}{float32}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t8: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t8 = ltorch.full((64, 64), 1, device=torch.device(\\PYZdq{}cuda:0\\PYZdq{}), dtype=torch.float32) \\PYZsh{} t8: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t8 = prims.full((64, 64), 1, device=devices.Device(\\PYZdq{}cuda:0\\PYZdq{}), dtype=dtypes.float32) \\PYZsh{} t8: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{n}{t9} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{full}\\PY{p}{(}\\PY{p}{(}\\PY{l+m+mi}{64}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{device}\\PY{o}{=}\\PY{n}{torch}\\PY{o}{.}\\PY{n}{device}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{cuda:0}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{dtype}\\PY{o}{=}\\PY{n}{torch}\\PY{o}{.}\\PY{n}{float32}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t9: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t9 = ltorch.full((64, 64), 1, device=torch.device(\\PYZdq{}cuda:0\\PYZdq{}), dtype=torch.float32) \\PYZsh{} t9: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t9 = prims.full((64, 64), 1, device=devices.Device(\\PYZdq{}cuda:0\\PYZdq{}), dtype=dtypes.float32) \\PYZsh{} t9: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{n}{t10} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{full}\\PY{p}{(}\\PY{p}{(}\\PY{l+m+mi}{64}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{,} \\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{device}\\PY{o}{=}\\PY{n}{torch}\\PY{o}{.}\\PY{n}{device}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{cuda:0}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{dtype}\\PY{o}{=}\\PY{n}{torch}\\PY{o}{.}\\PY{n}{float32}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t10: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t10 = ltorch.full((64, 64), 1, device=torch.device(\\PYZdq{}cuda:0\\PYZdq{}), dtype=torch.float32) \\PYZsh{} t10: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t10 = prims.full((64, 64), 1, device=devices.Device(\\PYZdq{}cuda:0\\PYZdq{}), dtype=dtypes.float32) \\PYZsh{} t10: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{n}{p11} \\PY{o}{=} \\PY{n}{torch\\PYZus{}all\\PYZus{}gather\\PYZus{}prim\\PYZus{}impl}\\PY{p}{(}\\PY{n}{t1}\\PY{p}{,} \\PY{n}{\\PYZus{}torch\\PYZus{}distributed\\PYZus{}distributed\\PYZus{}c10d\\PYZus{}ProcessGroup\\PYZus{}2}\\PY{p}{,} \\PY{k+kc}{True}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} p11: \\PYZdq{}FUTURE cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{k}{del} \\PY{n}{t1}\n",
- " \\PY{n}{t12} \\PY{o}{=} \\PY{n}{torch\\PYZus{}wait\\PYZus{}prim\\PYZus{}impl}\\PY{p}{(}\\PY{n}{p11}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t12: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{k}{del} \\PY{n}{p11}\n",
- " \\PY{n}{p13} \\PY{o}{=} \\PY{n}{torch\\PYZus{}all\\PYZus{}gather\\PYZus{}prim\\PYZus{}impl}\\PY{p}{(}\\PY{n}{t2}\\PY{p}{,} \\PY{n}{\\PYZus{}torch\\PYZus{}distributed\\PYZus{}distributed\\PYZus{}c10d\\PYZus{}ProcessGroup\\PYZus{}2}\\PY{p}{,} \\PY{k+kc}{True}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} p13: \\PYZdq{}FUTURE cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{k}{del} \\PY{n}{t2}\n",
- " \\PY{n}{t14} \\PY{o}{=} \\PY{n}{torch\\PYZus{}wait\\PYZus{}prim\\PYZus{}impl}\\PY{p}{(}\\PY{n}{p13}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t14: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{k}{del} \\PY{n}{p13}\n",
- " \\PY{n}{t15} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{nn}\\PY{o}{.}\\PY{n}{functional}\\PY{o}{.}\\PY{n}{linear}\\PY{p}{(}\\PY{n}{t0}\\PY{p}{,} \\PY{n}{t12}\\PY{p}{,} \\PY{k+kc}{None}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t15: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t15 = ltorch.linear(t0, t12, None) \\PYZsh{} t15: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t15 = prims.linear(t0, t12, None) \\PYZsh{} t15: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{n}{t16} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{tanh}\\PY{p}{(}\\PY{n}{t15}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t16: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t16 = ltorch.tanh(t15) \\PYZsh{} t16: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t16 = prims.tanh(t15) \\PYZsh{} t16: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{k}{del} \\PY{n}{t15}\n",
- " \\PY{n}{t17} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{nn}\\PY{o}{.}\\PY{n}{functional}\\PY{o}{.}\\PY{n}{linear}\\PY{p}{(}\\PY{n}{t16}\\PY{p}{,} \\PY{n}{t14}\\PY{p}{,} \\PY{k+kc}{None}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t17: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t17 = ltorch.linear(t16, t14, None) \\PYZsh{} t17: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t17 = prims.linear(t16, t14, None) \\PYZsh{} t17: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{n}{t18} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{add}\\PY{p}{(}\\PY{n}{t6}\\PY{p}{,} \\PY{n}{t7}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t18: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t18 = ltorch.add(t6, t7, alpha=None) \\PYZsh{} t18: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t18 = prims.add(t6, t7) \\PYZsh{} t18: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{k}{del} \\PY{n}{t6}\\PY{p}{,} \\PY{n}{t7}\n",
- " \\PY{n}{t19} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{add}\\PY{p}{(}\\PY{n}{t3}\\PY{p}{,} \\PY{n}{t8}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t19: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t19 = ltorch.add(t3, t8, alpha=None) \\PYZsh{} t19: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t19 = prims.add(t3, t8) \\PYZsh{} t19: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{k}{del} \\PY{n}{t3}\\PY{p}{,} \\PY{n}{t8}\n",
- " \\PY{n}{t20} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{add}\\PY{p}{(}\\PY{n}{t5}\\PY{p}{,} \\PY{n}{t9}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t20: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t20 = ltorch.add(t5, t9, alpha=None) \\PYZsh{} t20: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t20 = prims.add(t5, t9) \\PYZsh{} t20: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{k}{del} \\PY{n}{t5}\\PY{p}{,} \\PY{n}{t9}\n",
- " \\PY{n}{t21} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{t18}\\PY{p}{,} \\PY{p}{(}\\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t21: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t21 = ltorch.reshape(t18, (\\PYZhy{}1, 64)) \\PYZsh{} t21: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t21 = prims.reshape(t18, (64, 64)) \\PYZsh{} t21: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{n}{t22} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{matmul}\\PY{p}{(}\\PY{n}{t21}\\PY{p}{,} \\PY{n}{t14}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t22: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t22 = ltorch.matmul(t21, t14) \\PYZsh{} t22: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t22 = prims.matmul(t21, t14) \\PYZsh{} t22: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{k}{del} \\PY{n}{t21}\n",
- " \\PY{n}{t23} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{t18}\\PY{p}{,} \\PY{p}{(}\\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t23: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t23 = ltorch.reshape(t18, (\\PYZhy{}1, 64)) \\PYZsh{} t23: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t23 = prims.reshape(t18, (64, 64)) \\PYZsh{} t23: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{k}{del} \\PY{n}{t18}\n",
- " \\PY{n}{t24} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{permute}\\PY{p}{(}\\PY{n}{t23}\\PY{p}{,} \\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t24: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t24 = ltorch.permute(t23, (1, 0)) \\PYZsh{} t24: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t24 = prims.transpose(t23, (1, 0)) \\PYZsh{} t24: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{k}{del} \\PY{n}{t23}\n",
- " \\PY{n}{t25} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{t16}\\PY{p}{,} \\PY{p}{(}\\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t25: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t25 = ltorch.reshape(t16, (\\PYZhy{}1, 64)) \\PYZsh{} t25: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t25 = prims.reshape(t16, (64, 64)) \\PYZsh{} t25: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{n}{t26} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{matmul}\\PY{p}{(}\\PY{n}{t24}\\PY{p}{,} \\PY{n}{t25}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t26: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t26 = ltorch.matmul(t24, t25) \\PYZsh{} t26: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t26 = prims.matmul(t24, t25) \\PYZsh{} t26: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{k}{del} \\PY{n}{t24}\\PY{p}{,} \\PY{n}{t25}\n",
- " \\PY{n}{t27} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{add}\\PY{p}{(}\\PY{n}{t10}\\PY{p}{,} \\PY{n}{t22}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t27: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t27 = ltorch.add(t10, t22, alpha=None) \\PYZsh{} t27: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t27 = prims.add(t10, t22) \\PYZsh{} t27: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{k}{del} \\PY{n}{t10}\\PY{p}{,} \\PY{n}{t22}\n",
- " \\PY{n}{t28} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{add}\\PY{p}{(}\\PY{n}{t20}\\PY{p}{,} \\PY{n}{t26}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t28: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t28 = ltorch.add(t20, t26, alpha=None) \\PYZsh{} t28: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t28 = prims.add(t20, t26) \\PYZsh{} t28: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{k}{del} \\PY{n}{t20}\\PY{p}{,} \\PY{n}{t26}\n",
- " \\PY{n}{t29} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{mul}\\PY{p}{(}\\PY{n}{t16}\\PY{p}{,} \\PY{n}{t16}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t29: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t29 = ltorch.mul(t16, t16) \\PYZsh{} t29: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t29 = prims.mul(t16, t16) \\PYZsh{} t29: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{n}{t30} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{sub}\\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{n}{t29}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t30: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t30 = ltorch.sub(1, t29, alpha=None) \\PYZsh{} t30: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} \\PYZus{} = prims.convert\\PYZus{}element\\PYZus{}type(1, float)}\n",
- " \\PY{c+c1}{\\PYZsh{} t30 = prims.sub(1.0, t29) \\PYZsh{} t30: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{k}{del} \\PY{n}{t29}\n",
- " \\PY{n}{t31} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{mul}\\PY{p}{(}\\PY{n}{t27}\\PY{p}{,} \\PY{n}{t30}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t31: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t31 = ltorch.mul(t27, t30) \\PYZsh{} t31: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t31 = prims.mul(t27, t30) \\PYZsh{} t31: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{k}{del} \\PY{n}{t27}\\PY{p}{,} \\PY{n}{t30}\n",
- " \\PY{n}{t32} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{t31}\\PY{p}{,} \\PY{p}{(}\\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t32: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t32 = ltorch.reshape(t31, (\\PYZhy{}1, 64)) \\PYZsh{} t32: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t32 = prims.reshape(t31, (64, 64)) \\PYZsh{} t32: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{n}{t33} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{matmul}\\PY{p}{(}\\PY{n}{t32}\\PY{p}{,} \\PY{n}{t12}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t33: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t33 = ltorch.matmul(t32, t12) \\PYZsh{} t33: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t33 = prims.matmul(t32, t12) \\PYZsh{} t33: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{k}{del} \\PY{n}{t32}\n",
- " \\PY{n}{t34} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{t31}\\PY{p}{,} \\PY{p}{(}\\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t34: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t34 = ltorch.reshape(t31, (\\PYZhy{}1, 64)) \\PYZsh{} t34: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t34 = prims.reshape(t31, (64, 64)) \\PYZsh{} t34: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{k}{del} \\PY{n}{t31}\n",
- " \\PY{n}{t35} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{permute}\\PY{p}{(}\\PY{n}{t34}\\PY{p}{,} \\PY{p}{(}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{0}\\PY{p}{)}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t35: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t35 = ltorch.permute(t34, (1, 0)) \\PYZsh{} t35: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t35 = prims.transpose(t34, (1, 0)) \\PYZsh{} t35: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{k}{del} \\PY{n}{t34}\n",
- " \\PY{n}{t36} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{reshape}\\PY{p}{(}\\PY{n}{t0}\\PY{p}{,} \\PY{p}{(}\\PY{o}{\\PYZhy{}}\\PY{l+m+mi}{1}\\PY{p}{,} \\PY{l+m+mi}{64}\\PY{p}{)}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t36: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t36 = ltorch.reshape(t0, (\\PYZhy{}1, 64)) \\PYZsh{} t36: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t36 = prims.reshape(t0, (64, 64)) \\PYZsh{} t36: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{n}{t37} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{matmul}\\PY{p}{(}\\PY{n}{t35}\\PY{p}{,} \\PY{n}{t36}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t37: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t37 = ltorch.matmul(t35, t36) \\PYZsh{} t37: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t37 = prims.matmul(t35, t36) \\PYZsh{} t37: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{k}{del} \\PY{n}{t35}\\PY{p}{,} \\PY{n}{t36}\n",
- " \\PY{n}{t38} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{add}\\PY{p}{(}\\PY{n}{t19}\\PY{p}{,} \\PY{n}{t33}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t38: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t38 = ltorch.add(t19, t33, alpha=None) \\PYZsh{} t38: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t38 = prims.add(t19, t33) \\PYZsh{} t38: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{k}{del} \\PY{n}{t19}\\PY{p}{,} \\PY{n}{t33}\n",
- " \\PY{n}{t39} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{add}\\PY{p}{(}\\PY{n}{t4}\\PY{p}{,} \\PY{n}{t37}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t39: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t39 = ltorch.add(t4, t37, alpha=None) \\PYZsh{} t39: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t39 = prims.add(t4, t37) \\PYZsh{} t39: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{k}{del} \\PY{n}{t4}\\PY{p}{,} \\PY{n}{t37}\n",
- " \\PY{n}{t40} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{true\\PYZus{}divide}\\PY{p}{(}\\PY{n}{t28}\\PY{p}{,} \\PY{l+m+mi}{2}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t40: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t40 = ltorch.true\\PYZus{}divide(t28, 2) \\PYZsh{} t40: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} \\PYZus{} = prims.convert\\PYZus{}element\\PYZus{}type(2, float)}\n",
- " \\PY{c+c1}{\\PYZsh{} t40 = prims.div(t28, 2.0) \\PYZsh{} t40: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{k}{del} \\PY{n}{t28}\n",
- " \\PY{n}{p41} \\PY{o}{=} \\PY{n}{torch\\PYZus{}reduce\\PYZus{}scatter\\PYZus{}prim\\PYZus{}impl}\\PY{p}{(}\\PY{n}{t40}\\PY{p}{,} \\PY{n}{\\PYZus{}DistributedReduceOps\\PYZus{}3}\\PY{p}{,} \\PY{n}{\\PYZus{}torch\\PYZus{}distributed\\PYZus{}distributed\\PYZus{}c10d\\PYZus{}ProcessGroup\\PYZus{}2}\\PY{p}{,} \\PY{k+kc}{True}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} p41: \\PYZdq{}FUTURE cuda:0 f32[32, 64]\\PYZdq{}}\n",
- " \\PY{k}{del} \\PY{n}{t40}\n",
- " \\PY{n}{t42} \\PY{o}{=} \\PY{n}{torch\\PYZus{}wait\\PYZus{}prim\\PYZus{}impl}\\PY{p}{(}\\PY{n}{p41}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t42: \\PYZdq{}cuda:0 f32[32, 64]\\PYZdq{}}\n",
- " \\PY{k}{del} \\PY{n}{p41}\n",
- " \\PY{n}{t43} \\PY{o}{=} \\PY{n}{torch}\\PY{o}{.}\\PY{n}{true\\PYZus{}divide}\\PY{p}{(}\\PY{n}{t39}\\PY{p}{,} \\PY{l+m+mi}{2}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t43: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} t43 = ltorch.true\\PYZus{}divide(t39, 2) \\PYZsh{} t43: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{c+c1}{\\PYZsh{} \\PYZus{} = prims.convert\\PYZus{}element\\PYZus{}type(2, float)}\n",
- " \\PY{c+c1}{\\PYZsh{} t43 = prims.div(t39, 2.0) \\PYZsh{} t43: \\PYZdq{}cuda:0 f32[64, 64]\\PYZdq{}}\n",
- " \\PY{k}{del} \\PY{n}{t39}\n",
- " \\PY{n}{p44} \\PY{o}{=} \\PY{n}{torch\\PYZus{}reduce\\PYZus{}scatter\\PYZus{}prim\\PYZus{}impl}\\PY{p}{(}\\PY{n}{t43}\\PY{p}{,} \\PY{n}{\\PYZus{}DistributedReduceOps\\PYZus{}3}\\PY{p}{,} \\PY{n}{\\PYZus{}torch\\PYZus{}distributed\\PYZus{}distributed\\PYZus{}c10d\\PYZus{}ProcessGroup\\PYZus{}2}\\PY{p}{,} \\PY{k+kc}{True}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} p44: \\PYZdq{}FUTURE cuda:0 f32[32, 64]\\PYZdq{}}\n",
- " \\PY{k}{del} \\PY{n}{t43}\n",
- " \\PY{n}{t45} \\PY{o}{=} \\PY{n}{torch\\PYZus{}wait\\PYZus{}prim\\PYZus{}impl}\\PY{p}{(}\\PY{n}{p44}\\PY{p}{)} \\PY{c+c1}{\\PYZsh{} t45: \\PYZdq{}cuda:0 f32[32, 64]\\PYZdq{}}\n",
- " \\PY{k}{del} \\PY{n}{p44}\n",
- " \\PY{k}{return} \\PY{p}{(}\\PY{p}{(}\\PY{p}{\\PYZob{}}\\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{output}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{n}{t17}\\PY{p}{,} \\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{flat\\PYZus{}args}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{p}{[}\\PY{n}{t0}\\PY{p}{,} \\PY{n}{t12}\\PY{p}{,} \\PY{n}{t14}\\PY{p}{]}\\PY{p}{,} \\PY{l+s+s1}{\\PYZsq{}}\\PY{l+s+s1}{flat\\PYZus{}output}\\PY{l+s+s1}{\\PYZsq{}}\\PY{p}{:} \\PY{p}{(}\\PY{n}{t17}\\PY{p}{,}\\PY{p}{)}\\PY{p}{\\PYZcb{}}\\PY{p}{,} \\PY{p}{(}\\PY{p}{(}\\PY{n}{t0}\\PY{p}{,} \\PY{n}{t14}\\PY{p}{,} \\PY{n}{t16}\\PY{p}{)}\\PY{p}{,} \\PY{p}{(}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\\PY{p}{,} \\PY{p}{(}\\PY{n}{t38}\\PY{p}{,} \\PY{n}{t45}\\PY{p}{,} \\PY{n}{t42}\\PY{p}{)}\\PY{p}{)}\n",
- "\\end{Verbatim}\n"
- ],
- "text/plain": [
- "# Constructed by Delete Last Used (took 0 milliseconds)\n",
- "import torch\n",
- "import torch.nn.functional\n",
- "from thunder.executors.torchex import no_autocast\n",
- "\n",
- "@torch.no_grad()\n",
- "@no_autocast()\n",
- "def _value_and_grad(*args):\n",
- " # args: \"Collection\" \n",
- " t0, \\\n",
- " t1, \\\n",
- " t2, \\\n",
- " = args\n",
- " del args\n",
- " t3 = torch.full((64, 64), 1, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t3: \"cuda:0 f32[64, 64]\"\n",
- " # t3 = ltorch.full((64, 64), 1, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t3: \"cuda:0 f32[64, 64]\"\n",
- " # t3 = prims.full((64, 64), 1, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t3: \"cuda:0 f32[64, 64]\"\n",
- " t4 = torch.full((64, 64), 1, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t4: \"cuda:0 f32[64, 64]\"\n",
- " # t4 = ltorch.full((64, 64), 1, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t4: \"cuda:0 f32[64, 64]\"\n",
- " # t4 = prims.full((64, 64), 1, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t4: \"cuda:0 f32[64, 64]\"\n",
- " t5 = torch.full((64, 64), 1, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t5: \"cuda:0 f32[64, 64]\"\n",
- " # t5 = ltorch.full((64, 64), 1, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t5: \"cuda:0 f32[64, 64]\"\n",
- " # t5 = prims.full((64, 64), 1, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t5: \"cuda:0 f32[64, 64]\"\n",
- " t6 = torch.full((64, 64), 1, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t6: \"cuda:0 f32[64, 64]\"\n",
- " # t6 = ltorch.full((64, 64), 1, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t6: \"cuda:0 f32[64, 64]\"\n",
- " # t6 = prims.full((64, 64), 1, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t6: \"cuda:0 f32[64, 64]\"\n",
- " t7 = torch.full((64, 64), 1, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t7: \"cuda:0 f32[64, 64]\"\n",
- " # t7 = ltorch.full((64, 64), 1, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t7: \"cuda:0 f32[64, 64]\"\n",
- " # t7 = prims.full((64, 64), 1, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t7: \"cuda:0 f32[64, 64]\"\n",
- " t8 = torch.full((64, 64), 1, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t8: \"cuda:0 f32[64, 64]\"\n",
- " # t8 = ltorch.full((64, 64), 1, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t8: \"cuda:0 f32[64, 64]\"\n",
- " # t8 = prims.full((64, 64), 1, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t8: \"cuda:0 f32[64, 64]\"\n",
- " t9 = torch.full((64, 64), 1, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t9: \"cuda:0 f32[64, 64]\"\n",
- " # t9 = ltorch.full((64, 64), 1, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t9: \"cuda:0 f32[64, 64]\"\n",
- " # t9 = prims.full((64, 64), 1, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t9: \"cuda:0 f32[64, 64]\"\n",
- " t10 = torch.full((64, 64), 1, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t10: \"cuda:0 f32[64, 64]\"\n",
- " # t10 = ltorch.full((64, 64), 1, device=torch.device(\"cuda:0\"), dtype=torch.float32) # t10: \"cuda:0 f32[64, 64]\"\n",
- " # t10 = prims.full((64, 64), 1, device=devices.Device(\"cuda:0\"), dtype=dtypes.float32) # t10: \"cuda:0 f32[64, 64]\"\n",
- " p11 = torch_all_gather_prim_impl(t1, _torch_distributed_distributed_c10d_ProcessGroup_2, True) # p11: \"FUTURE cuda:0 f32[64, 64]\"\n",
- " del t1\n",
- " t12 = torch_wait_prim_impl(p11) # t12: \"cuda:0 f32[64, 64]\"\n",
- " del p11\n",
- " p13 = torch_all_gather_prim_impl(t2, _torch_distributed_distributed_c10d_ProcessGroup_2, True) # p13: \"FUTURE cuda:0 f32[64, 64]\"\n",
- " del t2\n",
- " t14 = torch_wait_prim_impl(p13) # t14: \"cuda:0 f32[64, 64]\"\n",
- " del p13\n",
- " t15 = torch.nn.functional.linear(t0, t12, None) # t15: \"cuda:0 f32[64, 64]\"\n",
- " # t15 = ltorch.linear(t0, t12, None) # t15: \"cuda:0 f32[64, 64]\"\n",
- " # t15 = prims.linear(t0, t12, None) # t15: \"cuda:0 f32[64, 64]\"\n",
- " t16 = torch.tanh(t15) # t16: \"cuda:0 f32[64, 64]\"\n",
- " # t16 = ltorch.tanh(t15) # t16: \"cuda:0 f32[64, 64]\"\n",
- " # t16 = prims.tanh(t15) # t16: \"cuda:0 f32[64, 64]\"\n",
- " del t15\n",
- " t17 = torch.nn.functional.linear(t16, t14, None) # t17: \"cuda:0 f32[64, 64]\"\n",
- " # t17 = ltorch.linear(t16, t14, None) # t17: \"cuda:0 f32[64, 64]\"\n",
- " # t17 = prims.linear(t16, t14, None) # t17: \"cuda:0 f32[64, 64]\"\n",
- " t18 = torch.add(t6, t7) # t18: \"cuda:0 f32[64, 64]\"\n",
- " # t18 = ltorch.add(t6, t7, alpha=None) # t18: \"cuda:0 f32[64, 64]\"\n",
- " # t18 = prims.add(t6, t7) # t18: \"cuda:0 f32[64, 64]\"\n",
- " del t6, t7\n",
- " t19 = torch.add(t3, t8) # t19: \"cuda:0 f32[64, 64]\"\n",
- " # t19 = ltorch.add(t3, t8, alpha=None) # t19: \"cuda:0 f32[64, 64]\"\n",
- " # t19 = prims.add(t3, t8) # t19: \"cuda:0 f32[64, 64]\"\n",
- " del t3, t8\n",
- " t20 = torch.add(t5, t9) # t20: \"cuda:0 f32[64, 64]\"\n",
- " # t20 = ltorch.add(t5, t9, alpha=None) # t20: \"cuda:0 f32[64, 64]\"\n",
- " # t20 = prims.add(t5, t9) # t20: \"cuda:0 f32[64, 64]\"\n",
- " del t5, t9\n",
- " t21 = torch.reshape(t18, (-1, 64)) # t21: \"cuda:0 f32[64, 64]\"\n",
- " # t21 = ltorch.reshape(t18, (-1, 64)) # t21: \"cuda:0 f32[64, 64]\"\n",
- " # t21 = prims.reshape(t18, (64, 64)) # t21: \"cuda:0 f32[64, 64]\"\n",
- " t22 = torch.matmul(t21, t14) # t22: \"cuda:0 f32[64, 64]\"\n",
- " # t22 = ltorch.matmul(t21, t14) # t22: \"cuda:0 f32[64, 64]\"\n",
- " # t22 = prims.matmul(t21, t14) # t22: \"cuda:0 f32[64, 64]\"\n",
- " del t21\n",
- " t23 = torch.reshape(t18, (-1, 64)) # t23: \"cuda:0 f32[64, 64]\"\n",
- " # t23 = ltorch.reshape(t18, (-1, 64)) # t23: \"cuda:0 f32[64, 64]\"\n",
- " # t23 = prims.reshape(t18, (64, 64)) # t23: \"cuda:0 f32[64, 64]\"\n",
- " del t18\n",
- " t24 = torch.permute(t23, (1, 0)) # t24: \"cuda:0 f32[64, 64]\"\n",
- " # t24 = ltorch.permute(t23, (1, 0)) # t24: \"cuda:0 f32[64, 64]\"\n",
- " # t24 = prims.transpose(t23, (1, 0)) # t24: \"cuda:0 f32[64, 64]\"\n",
- " del t23\n",
- " t25 = torch.reshape(t16, (-1, 64)) # t25: \"cuda:0 f32[64, 64]\"\n",
- " # t25 = ltorch.reshape(t16, (-1, 64)) # t25: \"cuda:0 f32[64, 64]\"\n",
- " # t25 = prims.reshape(t16, (64, 64)) # t25: \"cuda:0 f32[64, 64]\"\n",
- " t26 = torch.matmul(t24, t25) # t26: \"cuda:0 f32[64, 64]\"\n",
- " # t26 = ltorch.matmul(t24, t25) # t26: \"cuda:0 f32[64, 64]\"\n",
- " # t26 = prims.matmul(t24, t25) # t26: \"cuda:0 f32[64, 64]\"\n",
- " del t24, t25\n",
- " t27 = torch.add(t10, t22) # t27: \"cuda:0 f32[64, 64]\"\n",
- " # t27 = ltorch.add(t10, t22, alpha=None) # t27: \"cuda:0 f32[64, 64]\"\n",
- " # t27 = prims.add(t10, t22) # t27: \"cuda:0 f32[64, 64]\"\n",
- " del t10, t22\n",
- " t28 = torch.add(t20, t26) # t28: \"cuda:0 f32[64, 64]\"\n",
- " # t28 = ltorch.add(t20, t26, alpha=None) # t28: \"cuda:0 f32[64, 64]\"\n",
- " # t28 = prims.add(t20, t26) # t28: \"cuda:0 f32[64, 64]\"\n",
- " del t20, t26\n",
- " t29 = torch.mul(t16, t16) # t29: \"cuda:0 f32[64, 64]\"\n",
- " # t29 = ltorch.mul(t16, t16) # t29: \"cuda:0 f32[64, 64]\"\n",
- " # t29 = prims.mul(t16, t16) # t29: \"cuda:0 f32[64, 64]\"\n",
- " t30 = torch.sub(1, t29) # t30: \"cuda:0 f32[64, 64]\"\n",
- " # t30 = ltorch.sub(1, t29, alpha=None) # t30: \"cuda:0 f32[64, 64]\"\n",
- " # _ = prims.convert_element_type(1, float)\n",
- " # t30 = prims.sub(1.0, t29) # t30: \"cuda:0 f32[64, 64]\"\n",
- " del t29\n",
- " t31 = torch.mul(t27, t30) # t31: \"cuda:0 f32[64, 64]\"\n",
- " # t31 = ltorch.mul(t27, t30) # t31: \"cuda:0 f32[64, 64]\"\n",
- " # t31 = prims.mul(t27, t30) # t31: \"cuda:0 f32[64, 64]\"\n",
- " del t27, t30\n",
- " t32 = torch.reshape(t31, (-1, 64)) # t32: \"cuda:0 f32[64, 64]\"\n",
- " # t32 = ltorch.reshape(t31, (-1, 64)) # t32: \"cuda:0 f32[64, 64]\"\n",
- " # t32 = prims.reshape(t31, (64, 64)) # t32: \"cuda:0 f32[64, 64]\"\n",
- " t33 = torch.matmul(t32, t12) # t33: \"cuda:0 f32[64, 64]\"\n",
- " # t33 = ltorch.matmul(t32, t12) # t33: \"cuda:0 f32[64, 64]\"\n",
- " # t33 = prims.matmul(t32, t12) # t33: \"cuda:0 f32[64, 64]\"\n",
- " del t32\n",
- " t34 = torch.reshape(t31, (-1, 64)) # t34: \"cuda:0 f32[64, 64]\"\n",
- " # t34 = ltorch.reshape(t31, (-1, 64)) # t34: \"cuda:0 f32[64, 64]\"\n",
- " # t34 = prims.reshape(t31, (64, 64)) # t34: \"cuda:0 f32[64, 64]\"\n",
- " del t31\n",
- " t35 = torch.permute(t34, (1, 0)) # t35: \"cuda:0 f32[64, 64]\"\n",
- " # t35 = ltorch.permute(t34, (1, 0)) # t35: \"cuda:0 f32[64, 64]\"\n",
- " # t35 = prims.transpose(t34, (1, 0)) # t35: \"cuda:0 f32[64, 64]\"\n",
- " del t34\n",
- " t36 = torch.reshape(t0, (-1, 64)) # t36: \"cuda:0 f32[64, 64]\"\n",
- " # t36 = ltorch.reshape(t0, (-1, 64)) # t36: \"cuda:0 f32[64, 64]\"\n",
- " # t36 = prims.reshape(t0, (64, 64)) # t36: \"cuda:0 f32[64, 64]\"\n",
- " t37 = torch.matmul(t35, t36) # t37: \"cuda:0 f32[64, 64]\"\n",
- " # t37 = ltorch.matmul(t35, t36) # t37: \"cuda:0 f32[64, 64]\"\n",
- " # t37 = prims.matmul(t35, t36) # t37: \"cuda:0 f32[64, 64]\"\n",
- " del t35, t36\n",
- " t38 = torch.add(t19, t33) # t38: \"cuda:0 f32[64, 64]\"\n",
- " # t38 = ltorch.add(t19, t33, alpha=None) # t38: \"cuda:0 f32[64, 64]\"\n",
- " # t38 = prims.add(t19, t33) # t38: \"cuda:0 f32[64, 64]\"\n",
- " del t19, t33\n",
- " t39 = torch.add(t4, t37) # t39: \"cuda:0 f32[64, 64]\"\n",
- " # t39 = ltorch.add(t4, t37, alpha=None) # t39: \"cuda:0 f32[64, 64]\"\n",
- " # t39 = prims.add(t4, t37) # t39: \"cuda:0 f32[64, 64]\"\n",
- " del t4, t37\n",
- " t40 = torch.true_divide(t28, 2) # t40: \"cuda:0 f32[64, 64]\"\n",
- " # t40 = ltorch.true_divide(t28, 2) # t40: \"cuda:0 f32[64, 64]\"\n",
- " # _ = prims.convert_element_type(2, float)\n",
- " # t40 = prims.div(t28, 2.0) # t40: \"cuda:0 f32[64, 64]\"\n",
- " del t28\n",
- " p41 = torch_reduce_scatter_prim_impl(t40, _DistributedReduceOps_3, _torch_distributed_distributed_c10d_ProcessGroup_2, True) # p41: \"FUTURE cuda:0 f32[32, 64]\"\n",
- " del t40\n",
- " t42 = torch_wait_prim_impl(p41) # t42: \"cuda:0 f32[32, 64]\"\n",
- " del p41\n",
- " t43 = torch.true_divide(t39, 2) # t43: \"cuda:0 f32[64, 64]\"\n",
- " # t43 = ltorch.true_divide(t39, 2) # t43: \"cuda:0 f32[64, 64]\"\n",
- " # _ = prims.convert_element_type(2, float)\n",
- " # t43 = prims.div(t39, 2.0) # t43: \"cuda:0 f32[64, 64]\"\n",
- " del t39\n",
- " p44 = torch_reduce_scatter_prim_impl(t43, _DistributedReduceOps_3, _torch_distributed_distributed_c10d_ProcessGroup_2, True) # p44: \"FUTURE cuda:0 f32[32, 64]\"\n",
- " del t43\n",
- " t45 = torch_wait_prim_impl(p44) # t45: \"cuda:0 f32[32, 64]\"\n",
- " del p44\n",
- " return (({'output': t17, 'flat_args': [t0, t12, t14], 'flat_output': (t17,)}, ((t0, t14, t16), ())), (t38, t45, t42))"
- ]
- },
- "execution_count": 14,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"optimized_trace = thunder.transform_for_execution(forward_backward_trace, executors_list=thunder.get_always_executors())\n",
"\n",
@@ -1749,17 +333,9 @@
},
{
"cell_type": "code",
- "execution_count": 15,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Overwriting thunder_fsdp_simple_example.py\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"%%writefile thunder_fsdp_simple_example.py\n",
"\n",
@@ -1841,148 +417,9 @@
},
{
"cell_type": "code",
- "execution_count": 16,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "W0314 08:26:39.130000 140292199276608 torch/distributed/run.py:757] \n",
- "W0314 08:26:39.130000 140292199276608 torch/distributed/run.py:757] *****************************************\n",
- "W0314 08:26:39.130000 140292199276608 torch/distributed/run.py:757] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. \n",
- "W0314 08:26:39.130000 140292199276608 torch/distributed/run.py:757] *****************************************\n",
- "# Constructed by Delete Last Used (took 0 milliseconds)\n",
- "import torch\n",
- "import torch.nn.functional\n",
- "from thunder.executors.torchex import no_autocast\n",
- "\n",
- "@torch.no_grad()\n",
- "@no_autocast()\n",
- "def augmented_forward_fn(input, t_0_bias, t_2_bias, t_0_weight, t_2_weight):\n",
- " # input: \"cuda:0 f32[64, 64]\" \n",
- " # t_0_bias: \"cuda:0 f32[32]\" \n",
- " p0 = torch_all_gather_prim_impl(t_0_bias, _torch_distributed_distributed_c10d_ProcessGroup_1, True) # p0: \"FUTURE cuda:0 f32[64]\"\n",
- " # t_2_bias: \"cuda:0 f32[32]\" \n",
- " p2 = torch_all_gather_prim_impl(t_2_bias, _torch_distributed_distributed_c10d_ProcessGroup_1, True) # p2: \"FUTURE cuda:0 f32[64]\"\n",
- " # t_0_weight: \"cuda:0 f32[32, 64]\" \n",
- " p4 = torch_all_gather_prim_impl(t_0_weight, _torch_distributed_distributed_c10d_ProcessGroup_1, True) # p4: \"FUTURE cuda:0 f32[64, 64]\"\n",
- " # t_2_weight: \"cuda:0 f32[32, 64]\" \n",
- " p9 = torch_all_gather_prim_impl(t_2_weight, _torch_distributed_distributed_c10d_ProcessGroup_1, True) # p9: \"FUTURE cuda:0 f32[64, 64]\"\n",
- " t1 = torch_wait_prim_impl(p0) # t1: \"cuda:0 f32[64]\"\n",
- " del p0\n",
- " t3 = torch_wait_prim_impl(p2) # t3: \"cuda:0 f32[64]\"\n",
- " del p2\n",
- " t5 = torch_wait_prim_impl(p4) # t5: \"cuda:0 f32[64, 64]\"\n",
- " del p4\n",
- " t6 = torch.nn.functional.linear(input, t5, t1) # t6: \"cuda:0 f32[64, 64]\"\n",
- " # t6 = ltorch.linear(input, t5, t1) # t6: \"cuda:0 f32[64, 64]\"\n",
- " # t6 = prims.linear(input, t5, t1) # t6: \"cuda:0 f32[64, 64]\"\n",
- " del t5, t1\n",
- " [t7, t8] = nvFusion0(t6)\n",
- " # t7 = prims.gt(t6, 0.0) # t7: \"cuda:0 b8[64, 64]\"\n",
- " # t8 = prims.where(t7, t6, 0.0) # t8: \"cuda:0 f32[64, 64]\"\n",
- " del t6\n",
- " t10 = torch_wait_prim_impl(p9) # t10: \"cuda:0 f32[64, 64]\"\n",
- " del p9\n",
- " t11 = torch.nn.functional.linear(t8, t10, t3) # t11: \"cuda:0 f32[64, 64]\"\n",
- " # t11 = ltorch.linear(t8, t10, t3) # t11: \"cuda:0 f32[64, 64]\"\n",
- " # t11 = prims.linear(t8, t10, t3) # t11: \"cuda:0 f32[64, 64]\"\n",
- " del t3\n",
- " return {'output': t11, 'flat_args': [input, t_0_bias, t_2_bias, t_0_weight, t_2_weight], 'flat_output': (t11,)}, ((input, t10, t7, t8), ())\n",
- "********************************************************\n",
- "# Constructed by Delete Last Used (took 0 milliseconds)\n",
- "import torch\n",
- "from thunder.executors.torchex import no_autocast\n",
- "\n",
- "@torch.no_grad()\n",
- "@no_autocast()\n",
- "def backward_fn(saved_for_backward, cotangents):\n",
- " # saved_for_backward: \"Collection\" \n",
- " # cotangents: \"Collection\" \n",
- " C0, \\\n",
- " _, \\\n",
- " = saved_for_backward\n",
- " clear_collection(saved_for_backward)\n",
- " del saved_for_backward\n",
- " t0, \\\n",
- " = cotangents\n",
- " clear_collection(cotangents)\n",
- " del cotangents\n",
- " input, \\\n",
- " t10, \\\n",
- " t7, \\\n",
- " t8, \\\n",
- " = C0\n",
- " clear_collection(C0)\n",
- " del C0\n",
- " t31 = torch.reshape(t0, (-1, 64)) # t31: \"cuda:0 f32[64, 64]\"\n",
- " # t31 = ltorch.reshape(t0, (-1, 64)) # t31: \"cuda:0 f32[64, 64]\"\n",
- " # t31 = prims.reshape(t0, (64, 64)) # t31: \"cuda:0 f32[64, 64]\"\n",
- " t32 = torch.permute(t31, (1, 0)) # t32: \"cuda:0 f32[64, 64]\"\n",
- " # t32 = ltorch.permute(t31, (1, 0)) # t32: \"cuda:0 f32[64, 64]\"\n",
- " # t32 = prims.transpose(t31, (1, 0)) # t32: \"cuda:0 f32[64, 64]\"\n",
- " t33 = torch.reshape(t8, (-1, 64)) # t33: \"cuda:0 f32[64, 64]\"\n",
- " # t33 = ltorch.reshape(t8, (-1, 64)) # t33: \"cuda:0 f32[64, 64]\"\n",
- " # t33 = prims.reshape(t8, (64, 64)) # t33: \"cuda:0 f32[64, 64]\"\n",
- " del t8\n",
- " t45 = torch.reshape(input, (-1, 64)) # t45: \"cuda:0 f32[64, 64]\"\n",
- " # t45 = ltorch.reshape(input, (-1, 64)) # t45: \"cuda:0 f32[64, 64]\"\n",
- " # t45 = prims.reshape(input, (64, 64)) # t45: \"cuda:0 f32[64, 64]\"\n",
- " del input\n",
- " [t51] = nvFusion0(t0)\n",
- " # t35 = prims.sum(t0, (0,)) # t35: \"cuda:0 f32[64]\"\n",
- " # t51 = prims.div(t35, 2.0) # t51: \"cuda:0 f32[64]\"\n",
- " del t0\n",
- " p52 = torch_reduce_scatter_prim_impl(t51, _DistributedReduceOps_0, _torch_distributed_distributed_c10d_ProcessGroup_1, True) # p52: \"FUTURE cuda:0 f32[32]\"\n",
- " del t51\n",
- " t30 = torch.matmul(t31, t10) # t30: \"cuda:0 f32[64, 64]\"\n",
- " # t30 = ltorch.matmul(t29, t10) # t30: \"cuda:0 f32[64, 64]\"\n",
- " # t30 = prims.matmul(t29, t10) # t30: \"cuda:0 f32[64, 64]\"\n",
- " del t31, t10\n",
- " t34 = torch.matmul(t32, t33) # t34: \"cuda:0 f32[64, 64]\"\n",
- " # t34 = ltorch.matmul(t32, t33) # t34: \"cuda:0 f32[64, 64]\"\n",
- " # t34 = prims.matmul(t32, t33) # t34: \"cuda:0 f32[64, 64]\"\n",
- " del t32, t33\n",
- " [t36, t39, t54] = nvFusion1(t30, t34, t7)\n",
- " # t39 = prims.where(t7, t30, 0.0) # t39: \"cuda:0 f32[64, 64]\"\n",
- " # t47 = prims.sum(t39, (0,)) # t47: \"cuda:0 f32[64]\"\n",
- " # t54 = prims.div(t47, 2.0) # t54: \"cuda:0 f32[64]\"\n",
- " # t36 = prims.div(t34, 2.0) # t36: \"cuda:0 f32[64, 64]\"\n",
- " del t30, t34, t7\n",
- " p37 = torch_reduce_scatter_prim_impl(t36, _DistributedReduceOps_0, _torch_distributed_distributed_c10d_ProcessGroup_1, True) # p37: \"FUTURE cuda:0 f32[32, 64]\"\n",
- " del t36\n",
- " p55 = torch_reduce_scatter_prim_impl(t54, _DistributedReduceOps_0, _torch_distributed_distributed_c10d_ProcessGroup_1, True) # p55: \"FUTURE cuda:0 f32[32]\"\n",
- " del t54\n",
- " t43 = torch.reshape(t39, (-1, 64)) # t43: \"cuda:0 f32[64, 64]\"\n",
- " # t43 = ltorch.reshape(t39, (-1, 64)) # t43: \"cuda:0 f32[64, 64]\"\n",
- " # t43 = prims.reshape(t39, (64, 64)) # t43: \"cuda:0 f32[64, 64]\"\n",
- " del t39\n",
- " t44 = torch.permute(t43, (1, 0)) # t44: \"cuda:0 f32[64, 64]\"\n",
- " # t44 = ltorch.permute(t43, (1, 0)) # t44: \"cuda:0 f32[64, 64]\"\n",
- " # t44 = prims.transpose(t43, (1, 0)) # t44: \"cuda:0 f32[64, 64]\"\n",
- " del t43\n",
- " t46 = torch.matmul(t44, t45) # t46: \"cuda:0 f32[64, 64]\"\n",
- " # t46 = ltorch.matmul(t44, t45) # t46: \"cuda:0 f32[64, 64]\"\n",
- " # t46 = prims.matmul(t44, t45) # t46: \"cuda:0 f32[64, 64]\"\n",
- " del t44, t45\n",
- " [t48] = nvFusion2(t46)\n",
- " # t48 = prims.div(t46, 2.0) # t48: \"cuda:0 f32[64, 64]\"\n",
- " del t46\n",
- " p49 = torch_reduce_scatter_prim_impl(t48, _DistributedReduceOps_0, _torch_distributed_distributed_c10d_ProcessGroup_1, True) # p49: \"FUTURE cuda:0 f32[32, 64]\"\n",
- " del t48\n",
- " t53 = torch_wait_prim_impl(p52) # t53: \"cuda:0 f32[32]\"\n",
- " del p52\n",
- " t38 = torch_wait_prim_impl(p37) # t38: \"cuda:0 f32[32, 64]\"\n",
- " del p37\n",
- " t56 = torch_wait_prim_impl(p55) # t56: \"cuda:0 f32[32]\"\n",
- " del p55\n",
- " t50 = torch_wait_prim_impl(p49) # t50: \"cuda:0 f32[32, 64]\"\n",
- " del p49\n",
- " return (None, t56, t53, t50, t38)\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"!torchrun --nproc_per_node=2 thunder_fsdp_simple_example.py"
]
@@ -2017,7 +454,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.11.7"
+ "version": "3.10.13"
}
},
"nbformat": 4,
diff --git a/requirements/docs.txt b/requirements/docs.txt
index 14615a08c4..69547efb64 100644
--- a/requirements/docs.txt
+++ b/requirements/docs.txt
@@ -7,8 +7,10 @@ docutils >=0.16
sphinxcontrib-fulltoc ==1.2.0
sphinxcontrib-mockautodoc
-pt-lightning-sphinx-theme @ https://github.com/Lightning-AI/lightning_sphinx_theme/archive/master.zip
sphinx-autodoc-typehints ==1.23.0
sphinx-paramlinks ==0.6.0
sphinx-togglebutton ==0.3.2
sphinx-copybutton ==0.5.2
+
+# installed from S3 location and fetched in advance
+lai-sphinx-theme
diff --git a/thunder/__init__.py b/thunder/__init__.py
index 21ba6499ec..3386590a1f 100644
--- a/thunder/__init__.py
+++ b/thunder/__init__.py
@@ -44,7 +44,7 @@
from thunder.core.compile_data import compile_data_and_stats
from thunder.core.langctxs import LanguageContext
import thunder.core.langctxs as langctxs
-from thunder.core.baseutils import run_once
+from thunder.core.baseutils import run_once, check
from thunder.core.proxies import (
Proxy,
TensorProxy,
@@ -563,25 +563,38 @@ def get_computation_and_inputs(*args, **kwargs):
# thunder_backward may recursively call compile and wraps the result in a
# torch.autograd.Function to support embedding of Thunder-compiled
# functions in torch's Autograd
+
+            # Currently split_forward_backward also includes
+            # transform_for_execution and various sortings of symbols;
+            # applying transform_for_execution again after this point
+            # would break the order of operations.
computation_trc, backward_trc = split_forward_backward(computation_trc, cd, cs, *inps)
# Note computation_trc and backward_trc have been appended to cs.last_(backward_)traces
# by split_forward_backward
-
- cs.last_computation_transformation_start = time.time_ns()
-
- ## EPILOGUE and TRANSFORMS should not mix...
- # applies transforms
- for transform in additional_transforms:
- computation_trc = transform(computation_trc, executors_list=cd.executors_list)
- computation_traces.append(computation_trc)
-
- with langctxs.langctx(cd.langctx):
- extraces = transform_for_execution(
- computation_trc,
- executors_list=cd.executors_list,
- )
- extrace = extraces[-1]
- comp = extrace.python_callable()
+ extraces = cs.last_traces
+ check(
+ not additional_transforms,
+ lambda: "Specifying additional_transforms is not supported with PyTorch Autograd integration",
+ )
+
+ if backward_trc is None:
+ cs.last_computation_transformation_start = time.time_ns()
+
+ ## EPILOGUE and TRANSFORMS should not mix...
+ # applies transforms
+ for transform in additional_transforms:
+ computation_trc = transform(computation_trc, executors_list=cd.executors_list)
+ computation_traces.append(computation_trc)
+
+ with langctxs.langctx(cd.langctx):
+ extraces = transform_for_execution(
+ computation_trc,
+ executors_list=cd.executors_list,
+ )
+ computation_trc = extraces[-1]
+ cs.last_computation_transformation_stop = time.time_ns()
+
+ comp = computation_trc.python_callable()
if backward_trc is not None:
backward_fn = backward_trc.python_callable()
@@ -595,7 +608,6 @@ def get_computation_and_inputs(*args, **kwargs):
if cd.cache_option is not CACHE_OPTIONS.NO_CACHING:
cs.interpreter_cache.append(cache_entry)
- cs.last_computation_transformation_stop = time.time_ns()
cs.last_traces += extraces
cs.last_prologue_traces = [prologue_trc] + protraces
cs.last_prologue = pro
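
For readers of the hunk above: the restructuring means that when split_forward_backward produced a backward trace, lowering has already happened inside it, so transform_for_execution must not run a second time; only the backward-free path still applies additional_transforms and then lowers the trace. A self-contained sketch of that control flow follows, with trivial stand-ins (split_fwd_bwd, lower_for_execution) for Thunder's real functions:

```python
# Stand-alone sketch of the control flow introduced above; the helpers are
# hypothetical stand-ins for split_forward_backward / transform_for_execution.
def split_fwd_bwd(trace):
    # Pretend the split also lowers both traces for execution internally.
    return f"lowered({trace}).fwd", f"lowered({trace}).bwd"

def lower_for_execution(trace):
    return f"lowered({trace})"

def build_computation(trace, transforms, autograd_needed):
    if autograd_needed:
        # The split already lowered the traces; lowering again here would
        # reorder operations, so that path is skipped entirely.
        if transforms:
            raise ValueError("additional transforms unsupported with autograd")
        return split_fwd_bwd(trace)
    # No backward trace: apply user transforms, then lower for execution.
    for transform in transforms:
        trace = transform(trace)
    return lower_for_execution(trace), None

print(build_computation("trc", [], autograd_needed=True))
print(build_computation("trc", [str.upper], autograd_needed=False))
```
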
diff --git a/thunder/benchmarks/distributed.py b/thunder/benchmarks/distributed.py
index 3437c47e30..bc02528fe7 100644
--- a/thunder/benchmarks/distributed.py
+++ b/thunder/benchmarks/distributed.py
@@ -322,9 +322,11 @@ def parse_args() -> argparse.Namespace:
ResultFormatter(
model_name=args.model,
base_name="torch_fsdp",
- suffix=str(sharding_strategy).lower() + "-bucketing_" + "block"
- if auto_wrap_policy is not None
- else "none",
+ suffix=(
+ str(sharding_strategy).lower() + "-bucketing_" + "block"
+ if auto_wrap_policy is not None
+ else "none"
+ ),
dtype=args.dtype,
world_size=world_size,
total_callable_construction_time=total_cct,
@@ -352,9 +354,11 @@ def parse_args() -> argparse.Namespace:
ResultFormatter(
model_name=args.model,
base_name="torch_compile_fsdp",
- suffix=str(sharding_strategy).lower() + "-bucketing_" + "block"
- if auto_wrap_policy is not None
- else "none",
+ suffix=(
+ str(sharding_strategy).lower() + "-bucketing_" + "block"
+ if auto_wrap_policy is not None
+ else "none"
+ ),
dtype=args.dtype,
world_size=world_size,
total_callable_construction_time=total_cct,
diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py
index 6c0a6ffc0c..bbba0a08cb 100644
--- a/thunder/core/jit_ext.py
+++ b/thunder/core/jit_ext.py
@@ -377,7 +377,9 @@ class ThunderSharpEdgeError(RuntimeError):
def _sharp_edge(desc: str, value: Any, /) -> Any | INTERPRETER_SIGNALS:
sharp_edges: SHARP_EDGES_OPTIONS = get_minimal_ctx().sharp_edges
- s: str = f"{desc} is a sharp edge that cannot be translated to a thunder program unless using interpretation=INTERPRETATION_OPTIONS.TRANSLATE_PYTHON."
+ s: str = (
+ f"{desc} is a sharp edge that cannot be translated to a thunder program unless using interpretation=INTERPRETATION_OPTIONS.TRANSLATE_PYTHON."
+ )
if sharp_edges is SHARP_EDGES_OPTIONS.ERROR:
return do_raise(ThunderSharpEdgeError(s))
@@ -469,7 +471,9 @@ class JITSharpEdgeError(RuntimeError):
def _general_jit_sharp_edge(desc: str, value: Any, /) -> Any | INTERPRETER_SIGNALS:
sharp_edges: SHARP_EDGES_OPTIONS = get_minimal_ctx().sharp_edges
- s: str = f"{desc} This is currently considered a sharp edge even with interpretation=INTERPRETATION_OPTIONS.TRANSLATE_PYTHON. For cases in which we are overly strict, please file an issue. Thank you!"
+ s: str = (
+ f"{desc} This is currently considered a sharp edge even with interpretation=INTERPRETATION_OPTIONS.TRANSLATE_PYTHON. For cases in which we are overly strict, please file an issue. Thank you!"
+ )
if sharp_edges is SHARP_EDGES_OPTIONS.ERROR:
return do_raise(JITSharpEdgeError(s))
diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py
index 772e65a84d..3bd25c3297 100644
--- a/thunder/core/transforms.py
+++ b/thunder/core/transforms.py
@@ -3802,7 +3802,7 @@ def augmented_forward_fn(*args, **kwargs):
# Copy the signature of the original function so that the arguments are
# named correctly in the augmented forward pass instead of being named
# "args" and "kwargs".
- augmented_forward_fn.__signature__ = inspect.signature(trace.fn)
+ augmented_forward_fn.__signature__ = inspect.signature(trace.fn or trace.python_callable())
def ones_like(x):
if isinstance(x, TensorProxy):
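
The one-line change above only adds a fallback to trace.python_callable() when trace.fn is unset; the mechanism it relies on is standard Python. inspect.signature honors a __signature__ attribute, so copying one onto a *args wrapper preserves the original parameter names, which the new test_preserve_weight_names test further down checks. A minimal illustration:

```python
import inspect

def forward(input, weight, bias):
    return input @ weight + bias

def augmented_forward(*args, **kwargs):
    # ... augmented logic would go here ...
    return forward(*args, **kwargs)

# Without this, the wrapper reports (*args, **kwargs); with it, callers
# inspecting the signature see the real parameter names.
augmented_forward.__signature__ = inspect.signature(forward)

print(inspect.signature(augmented_forward))  # (input, weight, bias)
assert "weight" in inspect.signature(augmented_forward).parameters
```
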
diff --git a/thunder/core/utils.py b/thunder/core/utils.py
index 1edebe70e4..f0d4764fc4 100644
--- a/thunder/core/utils.py
+++ b/thunder/core/utils.py
@@ -743,16 +743,13 @@ class FrozenDict(_UserDictT[T, T1], Mapping[T, T1]):
"""
@overload
- def __init__(self, data: Mapping[T, T1]) -> None:
- ...
+ def __init__(self, data: Mapping[T, T1]) -> None: ...
@overload
- def __init__(self, data: Iterable[T, T1]) -> None:
- ...
+ def __init__(self, data: Iterable[T, T1]) -> None: ...
@overload
- def __init__(self, *args: Any, **kwargs: Any) -> None:
- ...
+ def __init__(self, *args: Any, **kwargs: Any) -> None: ...
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
@@ -834,8 +831,7 @@ def _safe_zip_gen(*args):
@overload
-def safe_zip(x: Iterable[T], y: Iterable[T1], /) -> Iterable[tuple[T, T1]]:
- ...
+def safe_zip(x: Iterable[T], y: Iterable[T1], /) -> Iterable[tuple[T, T1]]: ...
def safe_zip(*args):
diff --git a/thunder/core/vjp_utils.py b/thunder/core/vjp_utils.py
index 3c8128ffeb..8910dc3110 100644
--- a/thunder/core/vjp_utils.py
+++ b/thunder/core/vjp_utils.py
@@ -1,11 +1,12 @@
import inspect
+from collections.abc import Callable
+from functools import wraps
from inspect import Parameter, Signature
from itertools import chain
-from collections.abc import Callable
from thunder.core import prims, utils
from thunder.core.prims import PrimIDs
-from thunder.core.proxies import variableify, Proxy
+from thunder.core.proxies import Proxy, variableify
from thunder.core.pytree import tree_flatten, tree_map
from thunder.core.symbol import BoundSymbol
from thunder.core.trace import from_trace, TraceCtx
@@ -15,6 +16,16 @@
_cache = {}
+def disable_caching_split_forward_and_backward(fn):
+ @wraps(fn)
+ def wrapper(*args, **kwargs):
+ return fn(*args, **kwargs)
+
+ wrapper._disable_caching = True
+
+ return wrapper
+
+
def make_aug_forward_and_backward(bsym: BoundSymbol) -> tuple[Callable, Callable]:
"""
Given a bound symbol, return a pair of forward and backward functions
@@ -46,7 +57,7 @@ def make_aug_forward_and_backward(bsym: BoundSymbol) -> tuple[Callable, Callable
key = (bsym.sym, subkey := _make_cache_key(bsym.args, bsym.kwargs))
cached_result = _cache.get(key, None) if subkey is not None else None
- if cached_result is not None:
+ if cached_result is not None and not getattr(joint_forward_backward, "_disable_caching", False):
return cached_result
joint_trace = thunder.trace(inline_trace=False, use_dce=False)(joint_forward_backward, *bsym.args, **bsym.kwargs)
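
The decorator added above is deliberately inert at call time: it only tags the wrapped function, and the cache lookup in make_aug_forward_and_backward consults the tag via getattr before serving a hit (the transformer_engineex diff below applies it to _linear_grad). A reduced, self-contained sketch of the same tag-and-check pattern, with simplified names:

```python
from functools import wraps

_cache = {}

def disable_caching(fn):
    @wraps(fn)
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)
    # Only effect: a marker attribute the cache lookup inspects.
    wrapper._disable_caching = True
    return wrapper

def cached_call(fn, *args):
    key = (fn.__name__, args)
    hit = _cache.get(key)
    # Mirrors: cached_result is not None and not getattr(fn, "_disable_caching", False)
    if hit is not None and not getattr(fn, "_disable_caching", False):
        return hit
    result = fn(*args)
    _cache[key] = result
    return result

@disable_caching
def rule(x):
    return [x]  # returns a fresh object each call

# Never served from cache, even on the second call:
assert cached_call(rule, 1) is not cached_call(rule, 1)
```
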
diff --git a/thunder/distributed/utils.py b/thunder/distributed/utils.py
index 9d9fa7bf42..48cc66d4c8 100644
--- a/thunder/distributed/utils.py
+++ b/thunder/distributed/utils.py
@@ -84,12 +84,12 @@ def prefer_comm_over_other_over_wait_over_allgather(eligible_nodes: list[Node])
# nodes over "wait_prim_impl", pick "all_gather_prim_impl" last.
def key(node: Node) -> int:
match node.bsym.sym.id:
- case (wait_prim_impl.id | unpack_for_fsdp_prim_impl.id):
+ case wait_prim_impl.id | unpack_for_fsdp_prim_impl.id:
return len(order_in_trace)
- case (reduce_scatter_prim_impl.id | all_reduce_prim_impl.id):
+ case reduce_scatter_prim_impl.id | all_reduce_prim_impl.id:
# Prefer larger communication ops over smaller ones
return -node.bsym.args[0].numel
- case (all_gather_prim_impl.id):
+ case all_gather_prim_impl.id:
return len(order_in_trace) + order_in_trace[node.bsym]
case _:
# Prefer nodes that are earlier in the trace
@@ -141,9 +141,9 @@ def prefer_comm_over_other_over_wait(eligible_nodes: list[Node]) -> int:
# nodes over "wait_prim_impl"
def key(node: Node) -> int:
match node.bsym.sym.id:
- case (wait_prim_impl.id):
+ case wait_prim_impl.id:
return len(order_in_trace)
- case (reduce_scatter_prim_impl.id | all_reduce_prim_impl.id | all_gather_prim_impl.id):
+ case reduce_scatter_prim_impl.id | all_reduce_prim_impl.id | all_gather_prim_impl.id:
# Prefer larger communication ops over smaller ones
return -node.bsym.args[0].numel
case _:
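
The parentheses removed in this file are purely cosmetic (Black 24.x style); value patterns such as wait_prim_impl.id match identically either way. For context, the first key function above ranks ready nodes as: communication ops first (largest first), ordinary ops in trace order, waits after those, all-gathers last. A reduced sketch of that ranking over simplified (op, numel) nodes, with hypothetical names only:

```python
# Reduced sketch of the priority key above; the real code matches symbol ids
# on node.bsym.sym.id rather than strings.
def make_key(order_in_trace):
    def key(node):
        op, numel = node
        match op:
            case "wait":
                return len(order_in_trace)
            case "reduce_scatter" | "all_reduce":
                return -numel  # prefer larger communication ops
            case "all_gather":
                return len(order_in_trace) + order_in_trace[node]  # last
            case _:
                return order_in_trace[node]  # earlier in trace wins
    return key

nodes = [("wait", 0), ("all_reduce", 4096), ("matmul", 0), ("reduce_scatter", 64)]
order = {n: i for i, n in enumerate(nodes)}
print(sorted(nodes, key=make_key(order)))
# [('all_reduce', 4096), ('reduce_scatter', 64), ('matmul', 0), ('wait', 0)]
```
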
diff --git a/thunder/executors/torch_autograd.py b/thunder/executors/torch_autograd.py
index 93c0c2d874..f256ad338b 100644
--- a/thunder/executors/torch_autograd.py
+++ b/thunder/executors/torch_autograd.py
@@ -202,7 +202,7 @@ def make_trace(func):
if not any(requires_grad_mask):
raise RuntimeError("PyTorch's Autograd interface requires at least one tensor input with requires_grad=True")
- primal_trace = make_trace(func)(*args, **kwargs)
+ primal_trace = make_trace(func)(*args, **kwargs) if not compile_data.using_jit else computation_trc
primal_trace = sort_data_parallel_syncs(primal_trace)
if compile_stats is not None:
diff --git a/thunder/executors/transformer_engineex.py b/thunder/executors/transformer_engineex.py
index ba53c3e940..c1842f11ca 100644
--- a/thunder/executors/transformer_engineex.py
+++ b/thunder/executors/transformer_engineex.py
@@ -19,6 +19,7 @@
import thunder.core.prims as prims
from thunder.core.proxies import TensorProxy, CollectionProxy
from thunder.core.symbol import Symbol
+from thunder.core.vjp_utils import disable_caching_split_forward_and_backward
from thunder.extend import OperatorExecutor, register_executor
from thunder.core.langctxs import langctx, Languages
@@ -412,6 +413,7 @@ def _linear_transform(a: TensorProxy, w: TensorProxy, b: TensorProxy) -> torch.T
return _create_fp8_linear_bound_symbol(a, w, b, is_grad_enabled=False)
+@disable_caching_split_forward_and_backward
def _linear_grad(a: TensorProxy, w: TensorProxy, b: TensorProxy) -> TensorProxy:
out, saved_for_backward = linear_forwad_rule(a, w, b)
g = prims.get_grad(out)
diff --git a/thunder/tests/litgpt_model.py b/thunder/tests/litgpt_model.py
index 23b51545a0..13ab52f44b 100644
--- a/thunder/tests/litgpt_model.py
+++ b/thunder/tests/litgpt_model.py
@@ -1,4 +1,5 @@
"""Taken from https://github.com/Lightning-AI/litgpt/blob/main/litgpt/model.py"""
+
import torch
import torch.nn as nn
diff --git a/thunder/tests/test_core.py b/thunder/tests/test_core.py
index 54b27f4964..377a284a6c 100644
--- a/thunder/tests/test_core.py
+++ b/thunder/tests/test_core.py
@@ -2265,6 +2265,34 @@ def func(x, y, *, z):
assert bsym.flat_args == [1, 2, 3]
+@instantiate(dtypes=NOTHING)
+def test_preserve_weight_names(executor, device: str, dtype: dtypes.dtype):
+ import inspect
+
+ class MLP(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.fc1 = torch.nn.Linear(3, 4)
+ self.fc2 = torch.nn.Linear(4, 5)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.fc2(x)
+ return x
+
+ model = MLP().to(device=device, dtype=ltorch.to_torch_dtype(dtype))
+ x = torch.randn(2, 3, device=device, dtype=ltorch.to_torch_dtype(dtype))
+
+ compiled = thunder.jit(model, executors=executor.executors_list())
+ compiled(x)
+ traces = thunder.last_traces(compiled)
+ sig = inspect.signature(traces[-1].python_callable())
+ assert "t_fc1_bias" in sig.parameters
+ assert "t_fc1_weight" in sig.parameters
+ assert "t_fc2_bias" in sig.parameters
+ assert "t_fc2_weight" in sig.parameters
+
+
# @instantiate(
# dtypes=NOTHING,
# )
diff --git a/thunder/tests/test_transformer_engine_executor.py b/thunder/tests/test_transformer_engine_executor.py
index 41b6f83e36..e8abcc7c09 100644
--- a/thunder/tests/test_transformer_engine_executor.py
+++ b/thunder/tests/test_transformer_engine_executor.py
@@ -33,7 +33,7 @@ def test_te_linear_forward_backward():
# TE inputs (3D input)
x_te = torch.randn(3, 768, 4096, device=device, dtype=dtype, requires_grad=True)
te_linear1 = te.Linear(4096, 4096, params_dtype=dtype)
- te_linear2 = te.Linear(4096, 2048, params_dtype=dtype)
+ te_linear2 = te.Linear(4096, 4096, params_dtype=dtype)
# thunder inputs
x = x_te.detach().clone()