diff --git a/.azure/gpu-tests.yml b/.azure/gpu-tests.yml
index 22ba01eddc..1bf99b0aaa 100644
--- a/.azure/gpu-tests.yml
+++ b/.azure/gpu-tests.yml
@@ -65,8 +65,14 @@ jobs:
# drop pt (torch) from requirements so as not to interfere with the existing one
bash .azure/remove-torch-lines.sh requirements/base.txt
cat requirements/base.txt
+
# double check on test requirements
pip install -r requirements/test.txt
+
+ # https://docs.codecov.com/docs/codecov-uploader
+ curl -Os https://uploader.codecov.io/latest/linux/codecov
+ chmod +x codecov
+
# install this package
python setup.py develop
displayName: 'Install package & ...'
@@ -85,6 +91,12 @@ jobs:
--durations=250 \
--numprocesses=9 \
--ignore=thunder/tests/distributed --ignore=thunder/tests/test_networks.py
+ # compile coverage results
+ python -m coverage report
+ python -m coverage xml
+ # upload to codecov
+ ./codecov --token=$(CODECOV_TOKEN) --commit=$(Build.SourceVersion) \
+ --flags=gpu,pytest,regular --name="GPU-coverage" --env=linux,azure
condition: ne(variables['testing'], 'distributed')
displayName: 'Testing: regular'
@@ -95,6 +107,12 @@ jobs:
thunder/tests/test_networks.py \
-m "not standalone" \
-v --random-order-seed=42 --durations=0 --numprocesses=3
+ # compile coverage results
+ python -m coverage report
+ python -m coverage xml
+ # upload to codecov
+ ./codecov --token=$(CODECOV_TOKEN) --commit=$(Build.SourceVersion) \
+ --flags=gpu,pytest,networks --name="GPU-coverage" --env=linux,azure
condition: ne(variables['testing'], 'distributed')
displayName: 'Testing: networks'
@@ -108,6 +126,12 @@ jobs:
- bash: |
# run all found tests in the given path as standalone
bash scripts/run_standalone_tests.sh "thunder/tests/distributed"
+ # compile coverage results
+ python -m coverage report
+ python -m coverage xml
+ # upload to codecov
+ ./codecov --token=$(CODECOV_TOKEN) --commit=$(Build.SourceVersion) \
+ --flags=gpu,pytest,distributed --name="GPU-coverage" --env=linux,azure
condition: eq(variables['testing'], 'distributed')
displayName: 'Testing: distributed'
diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md
index a9d21a849b..1d016662eb 100644
--- a/.github/PULL_REQUEST_TEMPLATE.md
+++ b/.github/PULL_REQUEST_TEMPLATE.md
@@ -1,10 +1,13 @@
-# Before submitting
+
+ Before submitting
- [ ] Was this discussed/approved via a Github issue? (no need for typos and docs improvements)
- [ ] Did you read the [contributor guideline](https://github.com/Lightning-AI/pytorch-lightning/blob/main/.github/CONTRIBUTING.md), Pull Request section?
- [ ] Did you make sure to update the docs?
- [ ] Did you write any new necessary tests?
+
+
## What does this PR do?
Fixes # (issue).
diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml
index cbd8fb2aa6..310279fbbd 100644
--- a/.github/workflows/ci-testing.yml
+++ b/.github/workflows/ci-testing.yml
@@ -114,15 +114,15 @@ jobs:
coverage report
coverage xml
- #- name: Upload coverage to Codecov
- # uses: codecov/codecov-action@v3
- # with:
- # token: ${{ secrets.CODECOV_TOKEN }}
- # file: ./coverage.xml
- # flags: unittests
- # env_vars: OS,PYTHON
- # name: codecov-umbrella
- # fail_ci_if_error: false
+ - name: Upload coverage to Codecov
+ uses: codecov/codecov-action@v3
+ with:
+ token: ${{ secrets.CODECOV_TOKEN }}
+ file: ./coverage.xml
+ flags: unittests
+ env_vars: OS,PYTHON
+ name: codecov-umbrella
+ fail_ci_if_error: false
testing-guardian:
diff --git a/.github/workflows/docs-build.yml b/.github/workflows/docs-build.yml
index d17ae79eef..b8d87e466d 100644
--- a/.github/workflows/docs-build.yml
+++ b/.github/workflows/docs-build.yml
@@ -14,7 +14,7 @@ defaults:
shell: bash
jobs:
- build-docs:
+ docs-make:
uses: Lightning-AI/utilities/.github/workflows/check-docs.yml@v0.11.0
with:
python-version: "3.10"
@@ -28,7 +28,7 @@ jobs:
env:
GCP_TARGET: "gs://lightning-docs-thunder"
steps:
- - uses: actions/download-artifact@v4
+ - uses: actions/download-artifact@v3
with:
name: docs-html-${{ github.sha }}
path: docs/build/
@@ -50,7 +50,7 @@ jobs:
# Uploading docs to GCS, so they can be served on lightning.ai
- name: Upload docs/thunder/latest to GCS 🪣
- if: github.ref == 'refs/heads/master'
+ if: github.ref == 'refs/heads/main'
run: gsutil -m rsync -d -R docs/build/html/ ${GCP_TARGET}/latest
# Uploading docs to GCS, so they can be served on lightning.ai
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 7b7b8d1f6c..f16e1ef98f 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -17,6 +17,9 @@ repos:
- id: check-toml
- id: check-json
- id: check-added-large-files
+        args: ["--maxkb=250", "--enforce-all"]
- id: check-docstring-first
- id: detect-private-key
diff --git a/README.md b/README.md
index f89dce73b0..ce5399d570 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,6 @@
-
+
+
diff --git a/dockers/ubuntu-cuda/Dockerfile b/dockers/ubuntu-cuda/Dockerfile
index e815d827f6..2213921897 100644
--- a/dockers/ubuntu-cuda/Dockerfile
+++ b/dockers/ubuntu-cuda/Dockerfile
@@ -24,6 +24,7 @@ ARG CUDNN_FRONTEND_CHECKOUT="v1.1.0"
ARG PYTHON_VERSION="3.10"
ARG TORCH_VERSION="2.2.1"
ARG TRITON_VERSION="2.2.0"
+ARG TORCH_INSTALL="stable"
SHELL ["/bin/bash", "-c"]
# https://techoverflow.net/2019/05/18/how-to-fix-configuring-tzdata-interactive-input-when-building-docker-images/
@@ -96,7 +97,7 @@ ENV \
TORCH_CUDA_ARCH_LIST="8.0" \
CUDA_SELECT_NVCC_ARCH_FLAGS="8.0"
-ARG TORCH_INSTALL="wheel"
+ARG TORCH_INSTALL
RUN \
if [ "${TORCH_INSTALL}" == "source" ]; then \
@@ -122,15 +123,26 @@ RUN \
--index-url="https://download.pytorch.org/whl/cu${CUDA_VERSION_MM//'.'/''}"; \
fi
+ARG TORCH_INSTALL
+
RUN \
- # building nvFuser from source
- git clone https://github.com/NVIDIA/Fuser.git && \
- cd Fuser && \
- git submodule update --init --recursive && \
- pip install -r requirements.txt && \
- python setup.py install --no-test --no-benchmark && \
- cd .. && \
- rm -rf Fuser
+ if [ "${TORCH_INSTALL}" == "source" ]; then \
+ # building nvFuser from source
+ git clone https://github.com/NVIDIA/Fuser.git && \
+ cd Fuser && \
+ git submodule update --init --recursive && \
+ pip install -r requirements.txt && \
+ python setup.py install --no-test --no-benchmark && \
+ cd .. && \
+ rm -rf Fuser ; \
+ elif [ "${TORCH_INSTALL}" == "test" ]; then \
+    echo "Unsupported TORCH_INSTALL option" ; \
+ else \
+    # installing nvFuser from wheels \
+ CUDA_VERSION_MM=${CUDA_VERSION%.*} && \
+ TORCH_VERSION_MM=${TORCH_VERSION%.*} && \
+ pip install -U "nvfuser-cu${CUDA_VERSION_MM/./}-torch${TORCH_VERSION_MM/./}" ; \
+ fi
RUN \
ls -lh requirements/ && \
diff --git a/docs/source/_static/images/LightningThunderDarkModewByline.png b/docs/source/_static/images/LightningThunderDarkModewByline.png
new file mode 100644
index 0000000000..310739641c
Binary files /dev/null and b/docs/source/_static/images/LightningThunderDarkModewByline.png differ
diff --git a/docs/source/_static/images/LightningThunderLightModewByline.png b/docs/source/_static/images/LightningThunderLightModewByline.png
new file mode 100644
index 0000000000..7effda5ec9
Binary files /dev/null and b/docs/source/_static/images/LightningThunderLightModewByline.png differ
diff --git a/docs/source/_static/images/normalized_training_throughput_zero2.png b/docs/source/_static/images/normalized_training_throughput_zero2.png
index be6e5888c3..60dc0725a9 100644
Binary files a/docs/source/_static/images/normalized_training_throughput_zero2.png and b/docs/source/_static/images/normalized_training_throughput_zero2.png differ
diff --git a/docs/source/_static/images/training_throughput_single.png b/docs/source/_static/images/training_throughput_single.png
index 6c0a7029a4..ad66af29db 100644
Binary files a/docs/source/_static/images/training_throughput_single.png and b/docs/source/_static/images/training_throughput_single.png differ
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 052fa2437c..598dc42053 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -92,7 +92,7 @@ def _transform_changelog(path_in: str, path_out: str) -> None:
"sphinx.ext.linkcode",
"sphinx.ext.autosummary",
"sphinx.ext.napoleon",
- "sphinx.ext.imgmath",
+ "sphinx.ext.mathjax",
"myst_parser",
"nbsphinx",
"sphinx_autodoc_typehints",
@@ -209,6 +209,11 @@ def _transform_changelog(path_in: str, path_out: str) -> None:
(master_doc, project + ".tex", project + " Documentation", author, "manual"),
]
+# MathJax configuration
+mathjax3_config = {
+ "tex": {"packages": {"[+]": ["ams", "newcommand", "configMacros"]}},
+}
+
# -- Options for manual page output ------------------------------------------
# One entry per manual page. List of tuples
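For reference, `mathjax3_config` is passed straight through to MathJax v3, so project-specific macros could be declared next to the extra packages. A hypothetical sketch (the `macros` entries are illustrative and not part of this change):

# docs/source/conf.py (sketch) -- builds on the mathjax3_config added above
mathjax3_config = {
    "tex": {
        "packages": {"[+]": ["ams", "newcommand", "configMacros"]},
        # standard MathJax v3 "macros" option; names below are made-up examples
        "macros": {
            "R": r"\mathbb{R}",                 # \R -> real numbers
            "norm": [r"\lVert #1 \rVert", 1],   # \norm{x} takes one argument
        },
    },
}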
diff --git a/notebooks/dev_tutorials/fsdp_tutorial.ipynb b/notebooks/dev_tutorials/fsdp_tutorial.ipynb
index a4f61b47c3..b8cef2a2ff 100644
--- a/notebooks/dev_tutorials/fsdp_tutorial.ipynb
+++ b/notebooks/dev_tutorials/fsdp_tutorial.ipynb
@@ -1764,7 +1764,7 @@
"%%writefile thunder_fsdp_simple_example.py\n",
"\n",
"# imports\n",
- "from thunder.tests.lit_gpt_model import GPT, Config\n",
+ "from thunder.tests.litgpt_model import GPT, Config\n",
"import torch\n",
"import torch.distributed\n",
"import thunder\n",
diff --git a/notebooks/zero_to_thunder.ipynb b/notebooks/zero_to_thunder.ipynb
index a1a888cc72..ffd6b5fe1c 100644
--- a/notebooks/zero_to_thunder.ipynb
+++ b/notebooks/zero_to_thunder.ipynb
@@ -312,8 +312,8 @@
}
],
"source": [
- "from lit_gpt import GPT\n",
- "from thunder.tests.lit_gpt_model import Config\n",
+ "from litgpt import GPT\n",
+ "from thunder.tests.litgpt_model import Config\n",
"cfg = Config.from_name('Llama-2-7b-hf')\n",
"cfg.n_layer = 16 # fewer layers\n",
"torch.set_default_dtype(torch.bfloat16)\n",
@@ -3326,7 +3326,7 @@
],
"source": [
"%%writefile zero_to_thunder_fsdp_simple_example.py\n",
- "from thunder.tests.lit_gpt_model import GPT, Config\n",
+ "from thunder.tests.litgpt_model import GPT, Config\n",
"import os\n",
"import torch, torch.distributed\n",
"import thunder, thunder.distributed\n",
@@ -3470,7 +3470,7 @@
},
"outputs": [],
"source": [
- "import lit_gpt\n",
+ "import litgpt\n",
"def apply_rope_copy(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:\n",
" head_size = x.size(-1)\n",
" x1 = x[..., : head_size // 2] # (B, nh, T, hs/2)\n",
@@ -3493,7 +3493,7 @@
"\n",
"Say we have a function `apply_rope` applying the RoPE transformation in PyTorch.\n",
"\n",
- "In thunder, we define a *meta* function that only defines the metadata (like shapes) of outputs and the actual implementation for each operator and then register the pair with our executor using the `register_operator` function and tell it to use the new symbol instead of the original function `lit_gpt.model.apply_rope`.\n"
+    "In thunder, we define a *meta* function that only computes the metadata (like shapes) of the outputs, and a separate function with the actual implementation for each operator. We then register the pair with our executor using the `register_operator` function and tell it to use the new symbol instead of the original function `litgpt.model.apply_rope`.\n"
]
},
{
@@ -3504,17 +3504,17 @@
"outputs": [],
"source": [
"import torch, thunder\n",
- "from thunder.tests.lit_gpt_model import GPT\n",
+ "from thunder.tests.litgpt_model import GPT\n",
"from thunder import TensorProxy\n",
"\n",
"def apply_rope_impl(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:\n",
- " return lit_gpt.model.apply_rope(x, cos, sin)\n",
+ " return litgpt.model.apply_rope(x, cos, sin)\n",
"\n",
"def apply_rope_meta(x: TensorProxy, cos: TensorProxy, sin: TensorProxy) -> TensorProxy:\n",
" return TensorProxy(like=x)\n",
"\n",
"apply_rope = my_ex.register_operator('apply_rope', like=apply_rope_meta, fn=apply_rope_impl,\n",
- " replaces=lit_gpt.model.apply_rope)"
+ " replaces=litgpt.model.apply_rope)"
]
},
{
@@ -3569,7 +3569,7 @@
"with torch.device('cuda'): m = GPT.from_name('llama2-like'); Q = torch.randn(2, 128, 4096, 16)\n",
"\n",
"def test_apply_rope(x, m):\n",
- " return lit_gpt.model.apply_rope(x, m.cos, m.sin)\n",
+ " return litgpt.model.apply_rope(x, m.cos, m.sin)\n",
"\n",
"thunder_apply_rope = thunder.jit(test_apply_rope, executors=(my_ex,) + thunder.get_default_executors()) \n",
"\n",
diff --git a/requirements/notebooks.txt b/requirements/notebooks.txt
index 47a14902ca..7c37419cd6 100644
--- a/requirements/notebooks.txt
+++ b/requirements/notebooks.txt
@@ -1 +1,3 @@
ipython[all] ==8.22.2
+
+litgpt @ git+https://github.com/Lightning-AI/lit-gpt@24d5eba1724c953b7506edc041a7da1ce226c129
diff --git a/requirements/test.txt b/requirements/test.txt
index c882272833..621ebaa7c5 100644
--- a/requirements/test.txt
+++ b/requirements/test.txt
@@ -11,7 +11,7 @@ expecttest ==0.2.1 # for test_ddp.py
hypothesis ==6.99.10 # for test_ddp.py
numpy # for test_ops.py
einops # for test_einops.py
-lit_gpt @ git+https://github.com/Lightning-AI/lit-gpt@f241d94df59d82b2017bfdcd3800ac8779eb45f5
+litgpt @ git+https://github.com/Lightning-AI/lit-gpt@24d5eba1724c953b7506edc041a7da1ce226c129
absl-py # thunder/benchmarks/test_benchmark_litgpt.py
pandas # thunder/benchmarks/test_benchmark_litgpt.py
xlsxwriter # thunder/benchmarks/test_benchmark_litgpt.py
diff --git a/thunder/benchmarks/__init__.py b/thunder/benchmarks/__init__.py
index e9ba261c45..f4b0f2ac03 100644
--- a/thunder/benchmarks/__init__.py
+++ b/thunder/benchmarks/__init__.py
@@ -24,9 +24,9 @@
import thunder.core.devices as Devices
from thunder.core.transforms import grad, clear_grads, populate_grads
import thunder.executors as executors
-from thunder.tests import nanogpt_model, hf_bart_self_attn, lit_gpt_model
+from thunder.tests import nanogpt_model, hf_bart_self_attn, litgpt_model
from thunder.tests.make_tensor import make_tensor, make_tensor_like
-from thunder.tests.lit_gpt_model import Config as LitGPTConfig
+from thunder.tests.litgpt_model import Config as LitGPTConfig
# List of all benchmarks
benchmarks: list = []
@@ -1875,7 +1875,7 @@ class LlamaMLPBenchmark(Benchmark, metaclass=UserFacingBenchmarkMeta):
_args = (
BenchmarkArg(
name="config",
- description="The Lit-GPT config to use. Default is 'Llama-2-7b-hf'. See the lit_gpt_model.py for details.",
+ description="The Lit-GPT config to use. Default is 'Llama-2-7b-hf'. See the litgpt_model.py for details.",
),
BenchmarkArg(
name="batchdims",
@@ -1935,7 +1935,7 @@ def make_batch(self) -> tuple[list, dict]:
def fn(self) -> Callable:
module = (
- lit_gpt_model.LLaMAMLP(self.config)
+ litgpt_model.LLaMAMLP(self.config)
.to(device=self.device, dtype=self.tdtype)
.requires_grad_(self.requires_grad)
)
@@ -1946,7 +1946,7 @@ class LitGPTCausalSelfAttentionBenchmark(Benchmark, metaclass=UserFacingBenchmar
_args = (
BenchmarkArg(
name="config",
- description="The Lit-GPT config to use. Default is 'Llama-2-7b-hf'. See the lit_gpt_model.py for details.",
+ description="The Lit-GPT config to use. Default is 'Llama-2-7b-hf'. See the litgpt_model.py for details.",
),
BenchmarkArg(
name="batchdims",
@@ -2005,7 +2005,7 @@ def make_batch(self) -> tuple[list, dict]:
def fn(self) -> Callable:
module = (
- lit_gpt_model.CausalSelfAttention(self.config)
+ litgpt_model.CausalSelfAttention(self.config)
.to(device=self.device, dtype=self.tdtype)
.requires_grad_(self.requires_grad)
)
@@ -2086,7 +2086,7 @@ def make_batch(self) -> tuple[list, dict]:
def fn(self) -> Callable:
module = (
- lit_gpt_model.RMSNorm(self.size, self.dim, self.eps)
+ litgpt_model.RMSNorm(self.size, self.dim, self.eps)
.to(device=self.device, dtype=self.tdtype)
.requires_grad_(self.requires_grad)
)
@@ -2168,7 +2168,7 @@ def make_batch(self) -> tuple[list, dict]:
def fn(self) -> Callable:
gpt = (
- lit_gpt_model.GPT(self.config)
+ litgpt_model.GPT(self.config)
.to(device=self.device, dtype=self.model_tdtype)
.requires_grad_(self.requires_grad)
)
@@ -2199,7 +2199,7 @@ def __init__(self, config, use_apex) -> None:
super().__init__()
self.config = config
- self.apply_rope = lit_gpt_model.apply_rope
+ self.apply_rope = litgpt_model.apply_rope
self.use_apex = use_apex
def forward(
@@ -2254,7 +2254,7 @@ class LlamaQKVSplitRopeBenchmark(Benchmark, metaclass=UserFacingBenchmarkMeta):
_args = (
BenchmarkArg(
name="config",
- description="The Lit-GPT config to use. Default is 'Llama-2-7b-hf'. See the lit_gpt_model.py for details.",
+ description="The Lit-GPT config to use. Default is 'Llama-2-7b-hf'. See the litgpt_model.py for details.",
),
BenchmarkArg(
name="batchdims",
@@ -2610,7 +2610,7 @@ def __init__(
dtype: dtypes.dtype = thunder.bfloat16,
requires_grad: bool = True,
) -> None:
- from thunder.tests.lit_gpt_model import Config
+ from thunder.tests.litgpt_model import Config
litgptconfig = Config.from_name(config) if not isinstance(config, Config) else config
nanogptconfig = NanoGPTConfig(
@@ -2793,7 +2793,7 @@ def __init__(
# Sets required benchmark parameters
self.devices: list[str] = [device]
- self.cos, self.sin = lit_gpt_model.build_rope_cache(
+ self.cos, self.sin = litgpt_model.build_rope_cache(
seq_len=seq_length, n_elem=self.config.rope_n_elem, device=self.device
)
@@ -2806,9 +2806,7 @@ def make_batch(self) -> tuple[list, dict]:
def fn(self) -> Callable:
model = (
- lit_gpt_model.Block(self.config)
- .to(device=self.device, dtype=self.tdtype)
- .requires_grad_(self.requires_grad)
+ litgpt_model.Block(self.config).to(device=self.device, dtype=self.tdtype).requires_grad_(self.requires_grad)
)
return model
diff --git a/thunder/benchmarks/benchmark_litgpt.py b/thunder/benchmarks/benchmark_litgpt.py
index 9120584989..877f99f38e 100644
--- a/thunder/benchmarks/benchmark_litgpt.py
+++ b/thunder/benchmarks/benchmark_litgpt.py
@@ -7,7 +7,7 @@
import torch.distributed as torch_dist
import thunder
-from thunder.tests.lit_gpt_model import Config, GPT, Block
+from thunder.tests.litgpt_model import Config, GPT, Block
from lightning.fabric.utilities.throughput import measure_flops
from lightning.fabric.utilities import Throughput
diff --git a/thunder/benchmarks/distributed.py b/thunder/benchmarks/distributed.py
index c46a699ca5..3437c47e30 100644
--- a/thunder/benchmarks/distributed.py
+++ b/thunder/benchmarks/distributed.py
@@ -14,7 +14,7 @@
LitGPTBenchmark,
LitGPTConfig,
)
-from thunder.tests.lit_gpt_model import name_to_config
+from thunder.tests.litgpt_model import name_to_config
from thunder.distributed import FSDPBucketingStrategy
from thunder.distributed import FSDPType
@@ -299,7 +299,7 @@ def parse_args() -> argparse.Namespace:
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from thunder.benchmarks import get_default_torch_fsdp_executor
from thunder.tests.nanogpt_model import Block as NanoGPTBlock
- from thunder.tests.lit_gpt_model import Block as GPTBlock
+ from thunder.tests.litgpt_model import Block as GPTBlock
sharding_strategy = ShardingStrategy.SHARD_GRAD_OP
auto_wrap_policies = (
diff --git a/thunder/benchmarks/targets.py b/thunder/benchmarks/targets.py
index 9f02b23d35..9b38e26c3d 100644
--- a/thunder/benchmarks/targets.py
+++ b/thunder/benchmarks/targets.py
@@ -39,7 +39,7 @@
thunder_sdpa_torch_compile_nvfuser_executor,
)
-from thunder.tests.lit_gpt_model import Config as LitGPTConfig
+from thunder.tests.litgpt_model import Config as LitGPTConfig
APEX_FUSED_ROPE_AVAILABLE: bool = package_available("fused_rotary_positional_embedding")
diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py
index e626a77f2a..29046d5778 100644
--- a/thunder/core/jit_ext.py
+++ b/thunder/core/jit_ext.py
@@ -7,7 +7,8 @@
from collections.abc import Callable, Sequence
import weakref
import random
-from functools import partial, wraps
+from functools import partial, wraps, reduce
+import operator
import copy
import contextvars
from contextlib import contextmanager
@@ -1347,6 +1348,20 @@ def bind_inputs(name, trace, input_vars, input_proxies):
trace.args = input_proxies
+def _get_process_group_from(*fn_and_args) -> Optional["ProcessGroup"]:
+    # The `ddp` and `fsdp` transforms add the attribute `process_group_for_ddp`
+    # to the Module that they wrap. This module could be passed to `thunder.jit`
+    # either as the function to be jitted or as an argument of that function.
+ found_pg = None
+ for fn_or_arg in fn_and_args:
+ pg = getattr(fn_or_arg, "process_group_for_ddp", None)
+ if pg is not None and found_pg is None:
+ found_pg = pg
+ elif pg is not None and pg != found_pg:
+            raise NotImplementedError("jitting modules that use different ProcessGroups is not currently supported.")
+ return found_pg
+
+
def thunder_general_jit(
fn: Callable, args, kwargs, /, *, sharp_edges: SHARP_EDGES_OPTIONS
) -> tuple[TraceCtx, TraceCtx]:
@@ -1369,7 +1384,7 @@ def thunder_general_jit(
si.varkwargs = ("kwargs", None)
prologue_trace._siginfo = si
- process_group_for_ddp = getattr(fn, "process_group_for_ddp", None)
+ process_group_for_ddp: Optional["ProcessGroup"] = _get_process_group_from(fn, *args, *kwargs.values())
ctx: GeneralJitCtx = GeneralJitCtx(
prologue_trace, computation_trace, sharp_edges=sharp_edges, process_group_for_ddp=process_group_for_ddp
)
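For reference, a minimal sketch of the usage this enables (it mirrors the `test_ddp_model_as_argument` test added below, and assumes `torch.distributed` is already initialized and a CUDA device is available):

import torch
import thunder
import thunder.distributed

device = torch.device("cuda", 0)
model = torch.nn.Linear(5, 10, bias=False, device=device)
x = torch.randn(2, 5, device=device)

def fwd_loss(m, x):
    return m(x).sum()

# ddp() attaches `process_group_for_ddp` to the wrapped module; thunder.jit now
# discovers it from the call arguments via `_get_process_group_from`.
model = thunder.distributed.ddp(model)
fwd_loss = thunder.jit(fwd_loss)
fwd_loss(model, x)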
diff --git a/thunder/executors/triton_crossentropy.py b/thunder/executors/triton_crossentropy.py
index fc22770434..277e605e78 100644
--- a/thunder/executors/triton_crossentropy.py
+++ b/thunder/executors/triton_crossentropy.py
@@ -1,635 +1,10 @@
-import math
-from enum import Enum
-
-import torch
-
-
-import thunder.torch as ltorch
from thunder.executors import triton_utils
+from thunder.extend import OperatorExecutor
-
-# Requires triton 2.1 or greater
-min_triton_version = "2.1"
triton_version: None | str = triton_utils.triton_version()
-assert triton_version is not None, f"Trying to import a Triton executor, but Triton is unavailable"
-TRITON_AVAILABLE: bool = triton_utils.is_triton_version_at_least(min_triton_version)
-assert (
- TRITON_AVAILABLE
-), f"Trying to import a Triton executor, but it requires Triton version {min_triton_version} or greater, and the current Triton version is {triton_version}"
-
-from thunder.extend import OperatorExecutor, register_executor
-
-triton_ex: OperatorExecutor = OperatorExecutor("triton", version=triton_version)
-register_executor(triton_ex)
-
-import triton # noqa: E402
-import triton.language as tl # noqa: E402
-
-# Temporarily borrowed from https://github.com/openai/triton
-FORWARD_NUM_STAGES = 1
-
-
-class TritonDtype(Enum):
- kFP16 = 0
- kBF16 = 1
- kFP32 = 2
- kFP64 = 3
-
-
-_TORCH2DTYPE = {
- torch.float16: TritonDtype.kFP16,
- torch.bfloat16: TritonDtype.kBF16,
- torch.float32: TritonDtype.kFP32,
- torch.float64: TritonDtype.kFP64,
-}
-_DTYPE2TRITON = {
- TritonDtype.kFP16: tl.float16,
- TritonDtype.kBF16: tl.bfloat16,
- TritonDtype.kFP32: tl.float32,
- TritonDtype.kFP64: tl.float64,
-}
-
-
-@triton.jit
-def _class_indices_forward(
- LOGITS,
- PROBS,
- IDX,
- LOSS,
- weight,
- N,
- WEIGHT_BUFFER,
- smoothing_factor,
- log_size_logits,
- WEIGHTS: tl.constexpr,
- CLASS_INDICES: tl.constexpr,
- LABEL_SMOOTHING: tl.constexpr,
- IGNORE_INDEX: tl.constexpr,
- BUFFER_DTYPE: tl.constexpr,
- BLOCK: tl.constexpr,
-):
- buffer_dtype = _DTYPE2TRITON[BUFFER_DTYPE.value]
- row = tl.program_id(0)
- cols = tl.arange(0, BLOCK)
- logit_start_ptrs = LOGITS + row * N
- logit_ptrs = logit_start_ptrs + cols
- m_prev = -float("inf")
- l_prev = 0.0
- m_prev = m_prev.to(buffer_dtype)
- l_prev = l_prev.to(buffer_dtype)
-
- for start_n in range(0, tl.cdiv(N, BLOCK)):
- row_logits = tl.load(
- logit_ptrs,
- mask=cols < N - (start_n * BLOCK),
- other=-float("inf"),
- ).to(buffer_dtype)
-
- m_curr = tl.maximum(tl.max(row_logits, 0), m_prev)
- l_prev *= tl.exp(m_prev - m_curr)
- p = tl.exp(row_logits - m_curr)
- l_curr = tl.sum(p, 0) + l_prev
- l_prev = l_curr
- m_prev = m_curr
- logit_ptrs += BLOCK
- logit_ptrs = logit_start_ptrs + cols
- output_ptrs = PROBS + row * N + cols
- WRIT_PROBS = PROBS + row * N + cols
- if LABEL_SMOOTHING:
- sum_total = 0.0
- sum_total = sum_total.to(buffer_dtype)
- weights_total = 0.0
- weights_total = weights_total.to(buffer_dtype)
- if WEIGHTS:
- weight_ptr = weight + cols
-
- l_prev_log = tl.log(l_prev)
- # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
- for start_n in range(0, tl.cdiv(N, BLOCK)):
- row_logits = tl.load(
- logit_ptrs,
- mask=cols < N - start_n * BLOCK,
- other=l_prev_log + m_prev,
- ).to(buffer_dtype)
- if LABEL_SMOOTHING and WEIGHTS:
- full_weights_val = tl.load(weight_ptr, mask=cols < N - start_n * BLOCK, other=0.0)
- weights_total += tl.sum(full_weights_val, 0)
-
- row_minus_max = row_logits - m_prev
- log_softmax = l_prev_log - row_minus_max
-
- if LABEL_SMOOTHING and WEIGHTS:
- log_softmax *= full_weights_val
-
- if LABEL_SMOOTHING:
- sum_total += tl.sum(log_softmax, 0)
- # Store it back
-
- tl.store(
- WRIT_PROBS,
- log_softmax,
- mask=cols < N - start_n * BLOCK,
- )
- logit_ptrs += BLOCK
- WRIT_PROBS += BLOCK
- if LABEL_SMOOTHING and WEIGHTS:
- weight_ptr += BLOCK
-
- idx = tl.load(IDX + row)
- use_class = 0.0
- if IGNORE_INDEX >= 0:
- use_class = idx == IGNORE_INDEX
- READ_PROBS = PROBS + row * N + idx
- tl.debug_barrier()
- # write-back loss
- probs = tl.load(READ_PROBS)
- if WEIGHTS and not LABEL_SMOOTHING:
- weight_ptr = weight + idx
- weights_val = tl.load(weight_ptr)
- probs = weights_val * probs
- if LABEL_SMOOTHING:
- tl.store(WEIGHT_BUFFER + row, weights_total)
- probs = (1 - smoothing_factor) * probs + smoothing_factor * (sum_total) / N
- probs = probs * (1.0 - use_class)
-
- tl.store(LOSS + row, probs)
-
-
-@triton.jit
-def _class_probs_forward(
- LOGITS,
- PROBS,
- IDX,
- LOSS,
- weight,
- N,
- WEIGHT_BUFFER,
- smoothing_factor,
- log_size_logits,
- WEIGHTS: tl.constexpr,
- CLASS_INDICES: tl.constexpr,
- LABEL_SMOOTHING: tl.constexpr,
- IGNORE_INDEX: tl.constexpr,
- BUFFER_DTYPE: tl.constexpr,
- BLOCK: tl.constexpr,
-):
- buffer_dtype = _DTYPE2TRITON[BUFFER_DTYPE.value]
- row = tl.program_id(0)
- cols = tl.arange(0, BLOCK)
- logit_start_ptrs = LOGITS + row * N
- logit_ptrs = logit_start_ptrs + cols
- m_prev = -float("inf")
- l_prev = 0.0
- m_prev = m_prev.to(buffer_dtype)
- l_prev = l_prev.to(buffer_dtype)
-
- for start_n in range(0, tl.cdiv(N, BLOCK)):
- row_logits = tl.load(
- logit_ptrs,
- mask=cols < N - (start_n * BLOCK),
- other=-float("inf"),
- ).to(buffer_dtype)
-
- m_curr = tl.maximum(tl.max(row_logits, 0), m_prev)
- l_prev *= tl.exp(m_prev - m_curr)
- p = tl.exp(row_logits - m_curr)
- l_curr = tl.sum(p, 0) + l_prev
- l_prev = l_curr
- m_prev = m_curr
- logit_ptrs += BLOCK
- logit_ptrs = logit_start_ptrs + cols
- output_ptrs = PROBS + row * N + cols
- WRIT_PROBS = PROBS + row * N + cols
-
- sum_total = 0.0
- weights_total = 0.0
- sum_total = sum_total.to(buffer_dtype)
- weights_total = weights_total.to(buffer_dtype)
- idx_ptr = IDX + row * N + cols
- if WEIGHTS:
- weight_ptr = weight + cols
-
- l_prev_log = tl.log(l_prev)
- # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
- for start_n in range(0, tl.cdiv(N, BLOCK)):
- row_logits = tl.load(
- logit_ptrs,
- mask=cols < N - start_n * BLOCK,
- other=l_prev_log + m_prev,
- ).to(buffer_dtype)
- idx = tl.load(idx_ptr, mask=cols < N - start_n * BLOCK, other=0.0)
- full_weights_val = (1.0 - smoothing_factor) * idx + smoothing_factor / N
- if WEIGHTS:
- weights_val = tl.load(weight_ptr, mask=cols < N - start_n * BLOCK, other=0.0)
- full_weights_val = weights_val * full_weights_val
- else:
- full_weights_val = tl.where(cols < N - start_n * BLOCK, full_weights_val, 0.0)
- weights_total += tl.sum(full_weights_val, 0)
-
- row_minus_max = row_logits - m_prev
- log_softmax = l_prev_log - row_minus_max
-
- log_softmax *= full_weights_val
- sum_total += tl.sum(log_softmax, 0)
- # Store it back
-
- tl.store(
- WRIT_PROBS,
- log_softmax,
- mask=cols < N - start_n * BLOCK,
- )
- logit_ptrs += BLOCK
- WRIT_PROBS += BLOCK
- idx_ptr += BLOCK
- if WEIGHTS:
- weight_ptr += BLOCK
-
- tl.store(WEIGHT_BUFFER + row, weights_total)
- probs = sum_total
-
- tl.store(LOSS + row, probs)
-
-
-@triton.autotune(
- configs=[
- # fmt: off
- triton.Config({'BLOCK': 1024}, num_stages=FORWARD_NUM_STAGES, num_warps=1),
- triton.Config({'BLOCK': 2048}, num_stages=FORWARD_NUM_STAGES, num_warps=8),
- triton.Config({'BLOCK': 4096}, num_stages=FORWARD_NUM_STAGES, num_warps=8),
- triton.Config({'BLOCK': 8192}, num_stages=FORWARD_NUM_STAGES, num_warps=16),
- triton.Config({'BLOCK': 16384}, num_stages=FORWARD_NUM_STAGES, num_warps=16),
- # fmt: on
- ],
- key=[
- "N",
- "CLASS_INDICES",
- "log_size_logits",
- "BUFFER_DTYPE",
- ],
-)
-@triton.jit
-def _forward(
- LOGITS,
- PROBS,
- IDX,
- LOSS,
- weight,
- N,
- WEIGHT_BUFFER,
- smoothing_factor,
- log_size_logits,
- WEIGHTS: tl.constexpr,
- CLASS_INDICES: tl.constexpr,
- LABEL_SMOOTHING: tl.constexpr,
- IGNORE_INDEX: tl.constexpr,
- BUFFER_DTYPE: tl.constexpr,
- BLOCK: tl.constexpr,
-):
- if CLASS_INDICES:
- _class_indices_forward(
- LOGITS,
- PROBS,
- IDX,
- LOSS,
- weight,
- N,
- WEIGHT_BUFFER,
- smoothing_factor,
- log_size_logits,
- WEIGHTS,
- CLASS_INDICES,
- LABEL_SMOOTHING,
- IGNORE_INDEX,
- BUFFER_DTYPE,
- BLOCK,
- )
- else:
- _class_probs_forward(
- LOGITS,
- PROBS,
- IDX,
- LOSS,
- weight,
- N,
- WEIGHT_BUFFER,
- smoothing_factor,
- log_size_logits,
- WEIGHTS,
- CLASS_INDICES,
- LABEL_SMOOTHING,
- IGNORE_INDEX,
- BUFFER_DTYPE,
- BLOCK,
- )
-
-
-@triton.autotune(
- configs=[
- # fmt: off
- triton.Config({'BLOCK': 1024}, num_stages=1, num_warps=1),
- triton.Config({'BLOCK': 2048}, num_stages=1, num_warps=8),
- triton.Config({'BLOCK': 4096}, num_stages=1, num_warps=8),
- triton.Config({'BLOCK': 8192}, num_stages=1, num_warps=16),
- triton.Config({'BLOCK': 16384}, num_stages=1, num_warps=16),
- # fmt: on
- ],
- key=[
- "N",
- "CLASS_INDICES",
- "log_size_logits",
- "BUFFER_DTYPE",
- ],
-)
-@triton.jit
-def _backward(
- PROBS,
- IDX,
- DPROBS,
- dprob_stride,
- DIN,
- weight,
- N,
- WEIGHT_BUFFER,
- smoothing_factor,
- log_size_logits,
- WEIGHTS: tl.constexpr,
- CLASS_INDICES: tl.constexpr,
- LABEL_SMOOTHING: tl.constexpr,
- IGNORE_INDEX: tl.constexpr,
- BUFFER_DTYPE: tl.constexpr,
- BLOCK: tl.constexpr,
-):
- buffer_dtype = _DTYPE2TRITON[BUFFER_DTYPE.value]
- row = tl.program_id(0)
- start_n = tl.program_id(1)
- cols = tl.arange(0, BLOCK)
- PROBS = PROBS + row * N
- # pointers to probs
- probs_start = PROBS + cols + BLOCK * start_n
- # for start_n in range(0, tl.cdiv(N, BLOCK)): # need to change this
- probs = -tl.load(
- probs_start,
- mask=cols < N - (start_n * BLOCK),
- other=float("inf"),
- ).to(buffer_dtype)
- DIN = DIN + row * N + cols + BLOCK * start_n
- dout = tl.load(DPROBS + row * dprob_stride).to(buffer_dtype)
- # We know d(-log(p[i])/dlogit[k] = -id_mat[i,k] + p[k]
- # and we have -log(p[k]) stored in PROBS, so this is easy
- if CLASS_INDICES:
- idx = tl.load(IDX + row)
- delta = ((start_n * BLOCK) + cols) == idx
- # write result in-place in PROBS
- if IGNORE_INDEX >= 0:
- use_class = idx == IGNORE_INDEX
- dout = dout * (1 - use_class)
- if LABEL_SMOOTHING:
- if WEIGHTS:
- weight_ptr = weight + cols + BLOCK * start_n
- full_weights_val = tl.load(weight_ptr, mask=cols < N - start_n * BLOCK, other=0.0).to(buffer_dtype)
- weights_val = tl.load(weight + idx)
- probs = probs / full_weights_val
- probs = tl.exp(probs)
- if WEIGHTS:
- weights_total = tl.load(WEIGHT_BUFFER + row)
- numerator_contrib = weights_val * (1.0 - smoothing_factor) * (probs - delta)
- mean_contrib = ((weights_total * probs) - (full_weights_val)) * smoothing_factor / N
- else:
- numerator_contrib = (1.0 - smoothing_factor) * (probs - delta)
- mean_contrib = (smoothing_factor * probs) - (smoothing_factor / N)
-
- din = (numerator_contrib + mean_contrib) * dout
-
- else:
- probs = tl.exp(probs)
- din = (probs - delta) * dout
- if WEIGHTS:
- weight_ptr = weight + idx
- weights_val = tl.load(weight_ptr)
- din = weights_val * din
- else:
- idx = tl.load(
- IDX + row * N + cols + BLOCK * start_n,
- mask=cols < N - start_n * BLOCK,
- other=0.0,
- ).to(buffer_dtype)
- full_weights_val = (1.0 - smoothing_factor) * idx + smoothing_factor / N
- weights_total = tl.load(WEIGHT_BUFFER + row)
- if WEIGHTS:
- weight_ptr = weight + cols + BLOCK * start_n
- weights_val = tl.load(weight_ptr, mask=cols < N - start_n * BLOCK, other=0.0).to(buffer_dtype)
- full_weights_val = weights_val * full_weights_val
- probs = probs / full_weights_val
- probs = tl.exp(probs.to(buffer_dtype))
- weighted_probs = probs * weights_total
- weighted_probs_per_class = weighted_probs - full_weights_val
- din = (weighted_probs_per_class) * dout
-
- tl.store(DIN, din.to(DIN.dtype.element_ty), mask=cols + BLOCK * start_n < N)
-
-
-class CrossEntropy(torch.autograd.Function):
- @staticmethod
- def forward(
- ctx,
- logits,
- indices,
- weight,
- ignore_index,
- reduction,
- label_smoothing,
- ):
- buffer_dtype = None
- # make sure we can use triton
- # assert (
- # indices.dtype == torch.int64
- # ), "Indices are expected to be of type long."
- assert weight is None or (len(weight.shape) == 1 and weight.shape[0] == logits.shape[-1])
- # make kernel
- if buffer_dtype is None:
- if logits.dtype in [torch.bfloat16, torch.float16]:
- buffer_dtype = torch.float32
- else:
- buffer_dtype = logits.dtype
- buffer_dtype_enum = _TORCH2DTYPE[buffer_dtype]
- device, dtype = logits.device, logits.dtype
- n_cols = logits.shape[-1]
- # run the kernel
- result = torch.empty((logits.shape[0],), dtype=dtype, device=device)
- # result = torch.empty_like(indices, dtype=dtype, device=device)
- neg_logprobs = torch.empty_like(logits, dtype=buffer_dtype, device=device)
- weights_buffer = torch.empty_like(result, dtype=buffer_dtype)
- grid = lambda opt: (logits.numel() // n_cols,)
- log_size_logits = int(math.log(math.prod(logits.shape) / n_cols))
- _forward[grid](
- logits,
- neg_logprobs,
- indices,
- result,
- weight,
- n_cols,
- weights_buffer,
- label_smoothing,
- log_size_logits,
- WEIGHTS=(weight is not None),
- CLASS_INDICES=(indices.dtype == torch.int64),
- LABEL_SMOOTHING=(label_smoothing > 0.0),
- IGNORE_INDEX=ignore_index,
- BUFFER_DTYPE=buffer_dtype_enum,
- )
- # save for backward
- ctx.save_for_backward(neg_logprobs, indices, weights_buffer)
- ctx.WEIGHT = weight
- ctx.label_smoothing = label_smoothing
- ctx.ignore_index = ignore_index
- ctx.reduction = reduction
- ctx.buffer_dtype = buffer_dtype_enum
- if reduction == "none":
- return result
- elif reduction == "sum":
- return result.sum(dim=0)
- elif reduction == "mean":
- if indices.dtype == torch.int64:
- denom = (indices != ignore_index).float()
- if weight is not None:
- class_weights = weight[indices]
- denom *= class_weights
- denom = denom.sum()
- else:
- denom = indices.shape[0]
- ctx.denom = denom
- return (result.sum(dim=0) / denom).to(dtype)
-
- @staticmethod
- def backward(ctx, dneg_logprobs):
- """We know d(-log(p[i])/dlogit[k] = -id_mat[i,k] + p[k]
- so we initialize the gradient as neg_logprobs, so we can just exponentiate
- to get p[k], which is most of what we need... neg_logprobs will be
- modified in place to become the gradient we want
- """
- # load saved tensors
- reduction = ctx.reduction
- if reduction == "mean" or reduction == "sum":
- dneg_logprobs = dneg_logprobs.expand(1)
- neg_logprobs, indices, weights_buffer = ctx.saved_tensors
- din = torch.empty_like(neg_logprobs)
- weight = ctx.WEIGHT
- buffer_dtype = ctx.buffer_dtype
- # run the kernel
- # neg_logprobs will be modified in place to become our gradient:
- n_cols = neg_logprobs.shape[-1]
- grid = lambda opt: (
- neg_logprobs.numel() // n_cols,
- triton.cdiv(n_cols, opt["BLOCK"]),
- )
- log_size_logits = int(math.log(math.prod(neg_logprobs.shape) / n_cols))
- _backward[grid](
- neg_logprobs,
- indices,
- dneg_logprobs,
- dneg_logprobs.stride(0),
- din,
- weight,
- n_cols,
- weights_buffer,
- ctx.label_smoothing,
- log_size_logits,
- WEIGHTS=(weight is not None),
- CLASS_INDICES=(indices.dtype == torch.int64),
- LABEL_SMOOTHING=(ctx.label_smoothing > 0.0),
- IGNORE_INDEX=ctx.ignore_index,
- BUFFER_DTYPE=buffer_dtype,
- )
- if ctx.reduction == "mean":
- din /= ctx.denom
- return din, None, None, None, None, None, None
-
-
-def cross_entropy(
- input,
- target,
- weight=None,
- ignore_index=-100,
- reduction="mean",
- label_smoothing=0.0,
-):
- r"""
- Returns the Cross Entropy loss of input. If the target is class indcies
- then the ignore_index argument is applicable, while the label_smoothing argument
- is not. On the other hand, if the target is class probabilites, then the
- label_smoothing argument is applicable, while the ignore_index argument is not.
-
- Args:
- input: Tensor of shape (B, N)
- where B is the batch dim and N is the number of classes
- target: Int Tensor of shape (B,), min = 0, max = N-1 or
- Float Tensor of shape (B, N), rows sum to 1.0
- Int tensor of class labels.
- weight: Optional, Float Tensor of shape (N,)
- Weight to scale each class
- ignore_index: Int, which class label should be ignored
- reduction: String: ['none', 'sum', 'mean']
- label_smoothing: Float between 0 and 1
- """
- return CrossEntropy.apply(
- input,
- target,
- weight,
- ignore_index,
- reduction,
- label_smoothing,
- )
-
-
-# TODO: What is correct handling of ignore_index?
-def cross_entropy_impl(
- a,
- target,
- weight=None,
- size_average=None,
- ignore_index=-100,
- reduce=None,
- reduction="mean",
- label_smoothing=0.0,
-):
- loss = cross_entropy(a, target, weight, ignore_index, reduction, label_smoothing)
-
- return loss
-
-
-def cross_entropy_checker(
- a,
- /,
- target,
- weight=None,
- size_average=None,
- ignore_index=-100,
- reduce=None,
- reduction="mean",
- label_smoothing=0.0,
-) -> bool:
- if triton is None:
- return False
-
- torch_dtype = ltorch.to_torch_dtype(a.dtype)
- if torch_dtype not in (torch.float16, torch.bfloat16, torch.float32, torch.float64):
- return False
-
- # These arguments are deprecated and not supported
- if size_average is not None or reduce is not None:
- return False
-
- # We only support reduction of "sum", "mean" or "none"
- if reduction not in ["sum", "mean", "none"]:
- return False
-
- if len(a.shape) != 2:
- return False
-
- return True
-
-import thunder.torch as ltorch
+triton_ex: None | OperatorExecutor = None
+if triton_version is not None:
+ from thunder.executors.triton_crossentropy_impl import triton_ex as impl_ex
-ce = triton_ex.register_operator("triton_crossentropy", like=ltorch.cross_entropy, fn=cross_entropy_impl)
-triton_ex.register_implementation(ltorch.cross_entropy, ce, checker=cross_entropy_checker)
+ triton_ex = impl_ex
diff --git a/thunder/executors/triton_crossentropy_impl.py b/thunder/executors/triton_crossentropy_impl.py
new file mode 100644
index 0000000000..ff36fcb450
--- /dev/null
+++ b/thunder/executors/triton_crossentropy_impl.py
@@ -0,0 +1,631 @@
+import math
+from enum import Enum
+
+import torch
+
+from thunder.extend import OperatorExecutor, register_executor
+from thunder.executors import triton_utils
+
+# Requires triton 2.1 or greater
+min_triton_version = "2.1"
+
+triton_version: None | str = triton_utils.triton_version()
+TRITON_AVAILABLE: bool = triton_utils.is_triton_version_at_least(min_triton_version)
+assert (
+ TRITON_AVAILABLE
+), f"Trying to import a Triton executor, but it requires Triton version {min_triton_version} or greater, and the current Triton version is {triton_version}"
+
+triton_ex: OperatorExecutor = OperatorExecutor("triton", version=triton_version)
+register_executor(triton_ex)
+
+import triton # noqa: E402
+import triton.language as tl # noqa: E402
+
+# Temporarily borrowed from https://github.com/openai/triton
+FORWARD_NUM_STAGES = 1
+
+
+class TritonDtype(Enum):
+ kFP16 = 0
+ kBF16 = 1
+ kFP32 = 2
+ kFP64 = 3
+
+
+_TORCH2DTYPE = {
+ torch.float16: TritonDtype.kFP16,
+ torch.bfloat16: TritonDtype.kBF16,
+ torch.float32: TritonDtype.kFP32,
+ torch.float64: TritonDtype.kFP64,
+}
+_DTYPE2TRITON = {
+ TritonDtype.kFP16: tl.float16,
+ TritonDtype.kBF16: tl.bfloat16,
+ TritonDtype.kFP32: tl.float32,
+ TritonDtype.kFP64: tl.float64,
+}
+
+
+@triton.jit
+def _class_indices_forward(
+ LOGITS,
+ PROBS,
+ IDX,
+ LOSS,
+ weight,
+ N,
+ WEIGHT_BUFFER,
+ smoothing_factor,
+ log_size_logits,
+ WEIGHTS: tl.constexpr,
+ CLASS_INDICES: tl.constexpr,
+ LABEL_SMOOTHING: tl.constexpr,
+ IGNORE_INDEX: tl.constexpr,
+ BUFFER_DTYPE: tl.constexpr,
+ BLOCK: tl.constexpr,
+):
+ buffer_dtype = _DTYPE2TRITON[BUFFER_DTYPE.value]
+ row = tl.program_id(0)
+ cols = tl.arange(0, BLOCK)
+ logit_start_ptrs = LOGITS + row * N
+ logit_ptrs = logit_start_ptrs + cols
+ m_prev = -float("inf")
+ l_prev = 0.0
+ m_prev = m_prev.to(buffer_dtype)
+ l_prev = l_prev.to(buffer_dtype)
+
+ for start_n in range(0, tl.cdiv(N, BLOCK)):
+ row_logits = tl.load(
+ logit_ptrs,
+ mask=cols < N - (start_n * BLOCK),
+ other=-float("inf"),
+ ).to(buffer_dtype)
+
+ m_curr = tl.maximum(tl.max(row_logits, 0), m_prev)
+ l_prev *= tl.exp(m_prev - m_curr)
+ p = tl.exp(row_logits - m_curr)
+ l_curr = tl.sum(p, 0) + l_prev
+ l_prev = l_curr
+ m_prev = m_curr
+ logit_ptrs += BLOCK
+ logit_ptrs = logit_start_ptrs + cols
+ output_ptrs = PROBS + row * N + cols
+ WRIT_PROBS = PROBS + row * N + cols
+ if LABEL_SMOOTHING:
+ sum_total = 0.0
+ sum_total = sum_total.to(buffer_dtype)
+ weights_total = 0.0
+ weights_total = weights_total.to(buffer_dtype)
+ if WEIGHTS:
+ weight_ptr = weight + cols
+
+ l_prev_log = tl.log(l_prev)
+ # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
+ for start_n in range(0, tl.cdiv(N, BLOCK)):
+ row_logits = tl.load(
+ logit_ptrs,
+ mask=cols < N - start_n * BLOCK,
+ other=l_prev_log + m_prev,
+ ).to(buffer_dtype)
+ if LABEL_SMOOTHING and WEIGHTS:
+ full_weights_val = tl.load(weight_ptr, mask=cols < N - start_n * BLOCK, other=0.0)
+ weights_total += tl.sum(full_weights_val, 0)
+
+ row_minus_max = row_logits - m_prev
+ log_softmax = l_prev_log - row_minus_max
+
+ if LABEL_SMOOTHING and WEIGHTS:
+ log_softmax *= full_weights_val
+
+ if LABEL_SMOOTHING:
+ sum_total += tl.sum(log_softmax, 0)
+ # Store it back
+
+ tl.store(
+ WRIT_PROBS,
+ log_softmax,
+ mask=cols < N - start_n * BLOCK,
+ )
+ logit_ptrs += BLOCK
+ WRIT_PROBS += BLOCK
+ if LABEL_SMOOTHING and WEIGHTS:
+ weight_ptr += BLOCK
+
+ idx = tl.load(IDX + row)
+ use_class = 0.0
+ if IGNORE_INDEX >= 0:
+ use_class = idx == IGNORE_INDEX
+ READ_PROBS = PROBS + row * N + idx
+ tl.debug_barrier()
+ # write-back loss
+ probs = tl.load(READ_PROBS)
+ if WEIGHTS and not LABEL_SMOOTHING:
+ weight_ptr = weight + idx
+ weights_val = tl.load(weight_ptr)
+ probs = weights_val * probs
+ if LABEL_SMOOTHING:
+ tl.store(WEIGHT_BUFFER + row, weights_total)
+ probs = (1 - smoothing_factor) * probs + smoothing_factor * (sum_total) / N
+ probs = probs * (1.0 - use_class)
+
+ tl.store(LOSS + row, probs)
+
+
+@triton.jit
+def _class_probs_forward(
+ LOGITS,
+ PROBS,
+ IDX,
+ LOSS,
+ weight,
+ N,
+ WEIGHT_BUFFER,
+ smoothing_factor,
+ log_size_logits,
+ WEIGHTS: tl.constexpr,
+ CLASS_INDICES: tl.constexpr,
+ LABEL_SMOOTHING: tl.constexpr,
+ IGNORE_INDEX: tl.constexpr,
+ BUFFER_DTYPE: tl.constexpr,
+ BLOCK: tl.constexpr,
+):
+ buffer_dtype = _DTYPE2TRITON[BUFFER_DTYPE.value]
+ row = tl.program_id(0)
+ cols = tl.arange(0, BLOCK)
+ logit_start_ptrs = LOGITS + row * N
+ logit_ptrs = logit_start_ptrs + cols
+ m_prev = -float("inf")
+ l_prev = 0.0
+ m_prev = m_prev.to(buffer_dtype)
+ l_prev = l_prev.to(buffer_dtype)
+
+ for start_n in range(0, tl.cdiv(N, BLOCK)):
+ row_logits = tl.load(
+ logit_ptrs,
+ mask=cols < N - (start_n * BLOCK),
+ other=-float("inf"),
+ ).to(buffer_dtype)
+
+ m_curr = tl.maximum(tl.max(row_logits, 0), m_prev)
+ l_prev *= tl.exp(m_prev - m_curr)
+ p = tl.exp(row_logits - m_curr)
+ l_curr = tl.sum(p, 0) + l_prev
+ l_prev = l_curr
+ m_prev = m_curr
+ logit_ptrs += BLOCK
+ logit_ptrs = logit_start_ptrs + cols
+ output_ptrs = PROBS + row * N + cols
+ WRIT_PROBS = PROBS + row * N + cols
+
+ sum_total = 0.0
+ weights_total = 0.0
+ sum_total = sum_total.to(buffer_dtype)
+ weights_total = weights_total.to(buffer_dtype)
+ idx_ptr = IDX + row * N + cols
+ if WEIGHTS:
+ weight_ptr = weight + cols
+
+ l_prev_log = tl.log(l_prev)
+ # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
+ for start_n in range(0, tl.cdiv(N, BLOCK)):
+ row_logits = tl.load(
+ logit_ptrs,
+ mask=cols < N - start_n * BLOCK,
+ other=l_prev_log + m_prev,
+ ).to(buffer_dtype)
+ idx = tl.load(idx_ptr, mask=cols < N - start_n * BLOCK, other=0.0)
+ full_weights_val = (1.0 - smoothing_factor) * idx + smoothing_factor / N
+ if WEIGHTS:
+ weights_val = tl.load(weight_ptr, mask=cols < N - start_n * BLOCK, other=0.0)
+ full_weights_val = weights_val * full_weights_val
+ else:
+ full_weights_val = tl.where(cols < N - start_n * BLOCK, full_weights_val, 0.0)
+ weights_total += tl.sum(full_weights_val, 0)
+
+ row_minus_max = row_logits - m_prev
+ log_softmax = l_prev_log - row_minus_max
+
+ log_softmax *= full_weights_val
+ sum_total += tl.sum(log_softmax, 0)
+ # Store it back
+
+ tl.store(
+ WRIT_PROBS,
+ log_softmax,
+ mask=cols < N - start_n * BLOCK,
+ )
+ logit_ptrs += BLOCK
+ WRIT_PROBS += BLOCK
+ idx_ptr += BLOCK
+ if WEIGHTS:
+ weight_ptr += BLOCK
+
+ tl.store(WEIGHT_BUFFER + row, weights_total)
+ probs = sum_total
+
+ tl.store(LOSS + row, probs)
+
+
+@triton.autotune(
+ configs=[
+ # fmt: off
+ triton.Config({'BLOCK': 1024}, num_stages=FORWARD_NUM_STAGES, num_warps=1),
+ triton.Config({'BLOCK': 2048}, num_stages=FORWARD_NUM_STAGES, num_warps=8),
+ triton.Config({'BLOCK': 4096}, num_stages=FORWARD_NUM_STAGES, num_warps=8),
+ triton.Config({'BLOCK': 8192}, num_stages=FORWARD_NUM_STAGES, num_warps=16),
+ triton.Config({'BLOCK': 16384}, num_stages=FORWARD_NUM_STAGES, num_warps=16),
+ # fmt: on
+ ],
+ key=[
+ "N",
+ "CLASS_INDICES",
+ "log_size_logits",
+ "BUFFER_DTYPE",
+ ],
+)
+@triton.jit
+def _forward(
+ LOGITS,
+ PROBS,
+ IDX,
+ LOSS,
+ weight,
+ N,
+ WEIGHT_BUFFER,
+ smoothing_factor,
+ log_size_logits,
+ WEIGHTS: tl.constexpr,
+ CLASS_INDICES: tl.constexpr,
+ LABEL_SMOOTHING: tl.constexpr,
+ IGNORE_INDEX: tl.constexpr,
+ BUFFER_DTYPE: tl.constexpr,
+ BLOCK: tl.constexpr,
+):
+ if CLASS_INDICES:
+ _class_indices_forward(
+ LOGITS,
+ PROBS,
+ IDX,
+ LOSS,
+ weight,
+ N,
+ WEIGHT_BUFFER,
+ smoothing_factor,
+ log_size_logits,
+ WEIGHTS,
+ CLASS_INDICES,
+ LABEL_SMOOTHING,
+ IGNORE_INDEX,
+ BUFFER_DTYPE,
+ BLOCK,
+ )
+ else:
+ _class_probs_forward(
+ LOGITS,
+ PROBS,
+ IDX,
+ LOSS,
+ weight,
+ N,
+ WEIGHT_BUFFER,
+ smoothing_factor,
+ log_size_logits,
+ WEIGHTS,
+ CLASS_INDICES,
+ LABEL_SMOOTHING,
+ IGNORE_INDEX,
+ BUFFER_DTYPE,
+ BLOCK,
+ )
+
+
+@triton.autotune(
+ configs=[
+ # fmt: off
+ triton.Config({'BLOCK': 1024}, num_stages=1, num_warps=1),
+ triton.Config({'BLOCK': 2048}, num_stages=1, num_warps=8),
+ triton.Config({'BLOCK': 4096}, num_stages=1, num_warps=8),
+ triton.Config({'BLOCK': 8192}, num_stages=1, num_warps=16),
+ triton.Config({'BLOCK': 16384}, num_stages=1, num_warps=16),
+ # fmt: on
+ ],
+ key=[
+ "N",
+ "CLASS_INDICES",
+ "log_size_logits",
+ "BUFFER_DTYPE",
+ ],
+)
+@triton.jit
+def _backward(
+ PROBS,
+ IDX,
+ DPROBS,
+ dprob_stride,
+ DIN,
+ weight,
+ N,
+ WEIGHT_BUFFER,
+ smoothing_factor,
+ log_size_logits,
+ WEIGHTS: tl.constexpr,
+ CLASS_INDICES: tl.constexpr,
+ LABEL_SMOOTHING: tl.constexpr,
+ IGNORE_INDEX: tl.constexpr,
+ BUFFER_DTYPE: tl.constexpr,
+ BLOCK: tl.constexpr,
+):
+ buffer_dtype = _DTYPE2TRITON[BUFFER_DTYPE.value]
+ row = tl.program_id(0)
+ start_n = tl.program_id(1)
+ cols = tl.arange(0, BLOCK)
+ PROBS = PROBS + row * N
+ # pointers to probs
+ probs_start = PROBS + cols + BLOCK * start_n
+ # for start_n in range(0, tl.cdiv(N, BLOCK)): # need to change this
+ probs = -tl.load(
+ probs_start,
+ mask=cols < N - (start_n * BLOCK),
+ other=float("inf"),
+ ).to(buffer_dtype)
+ DIN = DIN + row * N + cols + BLOCK * start_n
+ dout = tl.load(DPROBS + row * dprob_stride).to(buffer_dtype)
+ # We know d(-log(p[i])/dlogit[k] = -id_mat[i,k] + p[k]
+ # and we have -log(p[k]) stored in PROBS, so this is easy
+ if CLASS_INDICES:
+ idx = tl.load(IDX + row)
+ delta = ((start_n * BLOCK) + cols) == idx
+ # write result in-place in PROBS
+ if IGNORE_INDEX >= 0:
+ use_class = idx == IGNORE_INDEX
+ dout = dout * (1 - use_class)
+ if LABEL_SMOOTHING:
+ if WEIGHTS:
+ weight_ptr = weight + cols + BLOCK * start_n
+ full_weights_val = tl.load(weight_ptr, mask=cols < N - start_n * BLOCK, other=0.0).to(buffer_dtype)
+ weights_val = tl.load(weight + idx)
+ probs = probs / full_weights_val
+ probs = tl.exp(probs)
+ if WEIGHTS:
+ weights_total = tl.load(WEIGHT_BUFFER + row)
+ numerator_contrib = weights_val * (1.0 - smoothing_factor) * (probs - delta)
+ mean_contrib = ((weights_total * probs) - (full_weights_val)) * smoothing_factor / N
+ else:
+ numerator_contrib = (1.0 - smoothing_factor) * (probs - delta)
+ mean_contrib = (smoothing_factor * probs) - (smoothing_factor / N)
+
+ din = (numerator_contrib + mean_contrib) * dout
+
+ else:
+ probs = tl.exp(probs)
+ din = (probs - delta) * dout
+ if WEIGHTS:
+ weight_ptr = weight + idx
+ weights_val = tl.load(weight_ptr)
+ din = weights_val * din
+ else:
+ idx = tl.load(
+ IDX + row * N + cols + BLOCK * start_n,
+ mask=cols < N - start_n * BLOCK,
+ other=0.0,
+ ).to(buffer_dtype)
+ full_weights_val = (1.0 - smoothing_factor) * idx + smoothing_factor / N
+ weights_total = tl.load(WEIGHT_BUFFER + row)
+ if WEIGHTS:
+ weight_ptr = weight + cols + BLOCK * start_n
+ weights_val = tl.load(weight_ptr, mask=cols < N - start_n * BLOCK, other=0.0).to(buffer_dtype)
+ full_weights_val = weights_val * full_weights_val
+ probs = probs / full_weights_val
+ probs = tl.exp(probs.to(buffer_dtype))
+ weighted_probs = probs * weights_total
+ weighted_probs_per_class = weighted_probs - full_weights_val
+ din = (weighted_probs_per_class) * dout
+
+ tl.store(DIN, din.to(DIN.dtype.element_ty), mask=cols + BLOCK * start_n < N)
+
+
+class CrossEntropy(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ logits,
+ indices,
+ weight,
+ ignore_index,
+ reduction,
+ label_smoothing,
+ ):
+ buffer_dtype = None
+ # make sure we can use triton
+ # assert (
+ # indices.dtype == torch.int64
+ # ), "Indices are expected to be of type long."
+ assert weight is None or (len(weight.shape) == 1 and weight.shape[0] == logits.shape[-1])
+ # make kernel
+ if buffer_dtype is None:
+ if logits.dtype in [torch.bfloat16, torch.float16]:
+ buffer_dtype = torch.float32
+ else:
+ buffer_dtype = logits.dtype
+ buffer_dtype_enum = _TORCH2DTYPE[buffer_dtype]
+ device, dtype = logits.device, logits.dtype
+ n_cols = logits.shape[-1]
+ # run the kernel
+ result = torch.empty((logits.shape[0],), dtype=dtype, device=device)
+ # result = torch.empty_like(indices, dtype=dtype, device=device)
+ neg_logprobs = torch.empty_like(logits, dtype=buffer_dtype, device=device)
+ weights_buffer = torch.empty_like(result, dtype=buffer_dtype)
+ grid = lambda opt: (logits.numel() // n_cols,)
+ log_size_logits = int(math.log(math.prod(logits.shape) / n_cols))
+ _forward[grid](
+ logits,
+ neg_logprobs,
+ indices,
+ result,
+ weight,
+ n_cols,
+ weights_buffer,
+ label_smoothing,
+ log_size_logits,
+ WEIGHTS=(weight is not None),
+ CLASS_INDICES=(indices.dtype == torch.int64),
+ LABEL_SMOOTHING=(label_smoothing > 0.0),
+ IGNORE_INDEX=ignore_index,
+ BUFFER_DTYPE=buffer_dtype_enum,
+ )
+ # save for backward
+ ctx.save_for_backward(neg_logprobs, indices, weights_buffer)
+ ctx.WEIGHT = weight
+ ctx.label_smoothing = label_smoothing
+ ctx.ignore_index = ignore_index
+ ctx.reduction = reduction
+ ctx.buffer_dtype = buffer_dtype_enum
+ if reduction == "none":
+ return result
+ elif reduction == "sum":
+ return result.sum(dim=0)
+ elif reduction == "mean":
+ if indices.dtype == torch.int64:
+ denom = (indices != ignore_index).float()
+ if weight is not None:
+ class_weights = weight[indices]
+ denom *= class_weights
+ denom = denom.sum()
+ else:
+ denom = indices.shape[0]
+ ctx.denom = denom
+ return (result.sum(dim=0) / denom).to(dtype)
+
+ @staticmethod
+ def backward(ctx, dneg_logprobs):
+ """We know d(-log(p[i])/dlogit[k] = -id_mat[i,k] + p[k]
+ so we initialize the gradient as neg_logprobs, so we can just exponentiate
+ to get p[k], which is most of what we need... neg_logprobs will be
+ modified in place to become the gradient we want
+ """
+ # load saved tensors
+ reduction = ctx.reduction
+ if reduction == "mean" or reduction == "sum":
+ dneg_logprobs = dneg_logprobs.expand(1)
+ neg_logprobs, indices, weights_buffer = ctx.saved_tensors
+ din = torch.empty_like(neg_logprobs)
+ weight = ctx.WEIGHT
+ buffer_dtype = ctx.buffer_dtype
+ # run the kernel
+ # neg_logprobs will be modified in place to become our gradient:
+ n_cols = neg_logprobs.shape[-1]
+ grid = lambda opt: (
+ neg_logprobs.numel() // n_cols,
+ triton.cdiv(n_cols, opt["BLOCK"]),
+ )
+ log_size_logits = int(math.log(math.prod(neg_logprobs.shape) / n_cols))
+ _backward[grid](
+ neg_logprobs,
+ indices,
+ dneg_logprobs,
+ dneg_logprobs.stride(0),
+ din,
+ weight,
+ n_cols,
+ weights_buffer,
+ ctx.label_smoothing,
+ log_size_logits,
+ WEIGHTS=(weight is not None),
+ CLASS_INDICES=(indices.dtype == torch.int64),
+ LABEL_SMOOTHING=(ctx.label_smoothing > 0.0),
+ IGNORE_INDEX=ctx.ignore_index,
+ BUFFER_DTYPE=buffer_dtype,
+ )
+ if ctx.reduction == "mean":
+ din /= ctx.denom
+ return din, None, None, None, None, None, None
+
+
+def cross_entropy(
+ input,
+ target,
+ weight=None,
+ ignore_index=-100,
+ reduction="mean",
+ label_smoothing=0.0,
+):
+ r"""
+    Returns the Cross Entropy loss of input. If the target contains class indices,
+    then the ignore_index argument is applicable, while the label_smoothing argument
+    is not. On the other hand, if the target contains class probabilities, then the
+    label_smoothing argument is applicable, while the ignore_index argument is not.
+
+ Args:
+ input: Tensor of shape (B, N)
+ where B is the batch dim and N is the number of classes
+        target: Int Tensor of shape (B,) with values in [0, N-1] (class indices), or
+            Float Tensor of shape (B, N) whose rows sum to 1.0 (class probabilities)
+ weight: Optional, Float Tensor of shape (N,)
+ Weight to scale each class
+ ignore_index: Int, which class label should be ignored
+ reduction: String: ['none', 'sum', 'mean']
+ label_smoothing: Float between 0 and 1
+ """
+ return CrossEntropy.apply(
+ input,
+ target,
+ weight,
+ ignore_index,
+ reduction,
+ label_smoothing,
+ )
+
+
+# TODO: What is correct handling of ignore_index?
+def cross_entropy_impl(
+ a,
+ target,
+ weight=None,
+ size_average=None,
+ ignore_index=-100,
+ reduce=None,
+ reduction="mean",
+ label_smoothing=0.0,
+):
+ loss = cross_entropy(a, target, weight, ignore_index, reduction, label_smoothing)
+
+ return loss
+
+
+def cross_entropy_checker(
+ a,
+ /,
+ target,
+ weight=None,
+ size_average=None,
+ ignore_index=-100,
+ reduce=None,
+ reduction="mean",
+ label_smoothing=0.0,
+) -> bool:
+ if triton is None:
+ return False
+
+ torch_dtype = ltorch.to_torch_dtype(a.dtype)
+ if torch_dtype not in (torch.float16, torch.bfloat16, torch.float32, torch.float64):
+ return False
+
+ # These arguments are deprecated and not supported
+ if size_average is not None or reduce is not None:
+ return False
+
+ # We only support reduction of "sum", "mean" or "none"
+ if reduction not in ["sum", "mean", "none"]:
+ return False
+
+ if len(a.shape) != 2:
+ return False
+
+ return True
+
+
+import thunder.torch as ltorch
+
+ce = triton_ex.register_operator("triton_crossentropy", like=ltorch.cross_entropy, fn=cross_entropy_impl)
+triton_ex.register_implementation(ltorch.cross_entropy, ce, checker=cross_entropy_checker)
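For reference, after this refactor `thunder.executors.triton_crossentropy.triton_ex` is `None` when Triton is unavailable, so callers should guard on it before requesting the executor. A minimal sketch under that assumption (the loss function and tensor shapes are illustrative):

import torch
import thunder
from thunder.executors.triton_crossentropy import triton_ex

def loss_fn(logits, targets):
    return torch.nn.functional.cross_entropy(logits, targets)

executors = thunder.get_default_executors()
if triton_ex is not None:  # triton_ex is None when Triton < 2.1 or missing
    executors = (triton_ex, *executors)

jitted_loss = thunder.jit(loss_fn, executors=executors)
logits = torch.randn(8, 32, device="cuda", requires_grad=True)
targets = torch.randint(0, 32, (8,), device="cuda")
out = jitted_loss(logits, targets)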
diff --git a/thunder/extend/__init__.py b/thunder/extend/__init__.py
index 2acdcf6427..a1fd0c8361 100644
--- a/thunder/extend/__init__.py
+++ b/thunder/extend/__init__.py
@@ -306,13 +306,9 @@ def get_all_executors() -> tuple[Executor]:
torch_compile,
torchex,
transformer_engineex,
+ triton_crossentropy,
)
- if torch.cuda.is_available():
- # raise an error when a dependency is not available at import time
- # TODO: this should only happen at runtime
- from thunder.executors import triton_crossentropy
-
return tuple(_executor_map.values())
diff --git a/thunder/tests/distributed/test_ddp.py b/thunder/tests/distributed/test_ddp.py
index c321a01394..3562a10a5c 100644
--- a/thunder/tests/distributed/test_ddp.py
+++ b/thunder/tests/distributed/test_ddp.py
@@ -843,6 +843,20 @@ def check_inflight_allgather_number(trc, n: int, is_bucket: bool):
aft_trc = limit_in_flight_allgathers(bwd_trc, i, is_bucketing)
check_inflight_allgather_number(aft_trc, i, is_bucketing)
+ def test_ddp_model_as_argument(self):
+        # Sanity test to make sure that passing a `ddp`-wrapped model as an
+        # argument to `thunder.jit` compiles.
+ device = torch.device("cuda", self.rank)
+ model = torch.nn.Linear(5, 10, bias=False, device=device)
+ x = torch.randn(2, 5, device=device)
+
+ def fwd_loss(m, x):
+ return m(x).sum()
+
+ model = thunder.distributed.ddp(model)
+ fwd_loss = thunder.jit(fwd_loss)
+ fwd_loss(model, x)
+
common_utils.instantiate_parametrized_tests(CompileDDPTest)
diff --git a/thunder/tests/lit_gpt_model.py b/thunder/tests/litgpt_model.py
similarity index 82%
rename from thunder/tests/lit_gpt_model.py
rename to thunder/tests/litgpt_model.py
index 57a85089bc..23b51545a0 100644
--- a/thunder/tests/lit_gpt_model.py
+++ b/thunder/tests/litgpt_model.py
@@ -1,4 +1,4 @@
-"""Taken from https://github.com/Lightning-AI/lit-gpt/blob/main/lit_gpt/model.py"""
+"""Taken from https://github.com/Lightning-AI/litgpt/blob/main/litgpt/model.py"""
import torch
import torch.nn as nn
@@ -18,9 +18,9 @@
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
- _norm_class="RMSNorm",
+ norm_class_name="RMSNorm",
norm_eps=1e-6,
- _mlp_class="LLaMAMLP",
+ mlp_class_name="LLaMAMLP",
intermediate_size=1376,
),
dict(
@@ -34,8 +34,8 @@
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
- _norm_class="RMSNorm",
- _mlp_class="LLaMAMLP",
+ norm_class_name="RMSNorm",
+ mlp_class_name="LLaMAMLP",
intermediate_size=11008,
rope_condense_ratio=4,
),
@@ -49,8 +49,8 @@
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
- _norm_class="RMSNorm",
- _mlp_class="LLaMAMLP",
+ norm_class_name="RMSNorm",
+ mlp_class_name="LLaMAMLP",
intermediate_size=1376,
),
dict(
@@ -87,9 +87,9 @@
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
- _norm_class="RMSNorm",
+ norm_class_name="RMSNorm",
norm_eps=1e-05,
- _mlp_class="LLaMAMLP",
+ mlp_class_name="LLaMAMLP",
intermediate_size=1376,
rope_base=1000000,
),
@@ -104,9 +104,9 @@
n_query_groups=8,
parallel_residual=False,
bias=False,
- _norm_class="RMSNorm",
+ norm_class_name="RMSNorm",
norm_eps=1e-05,
- _mlp_class="LLaMAMoE",
+ mlp_class_name="LLaMAMoE",
intermediate_size=224,
rope_base=1000000,
n_expert=8,
@@ -150,21 +150,20 @@ def reset_parameters(self) -> None:
torch.nn.init.zeros_(self.v)
-import lit_gpt
-import lit_gpt.rmsnorm
+import litgpt
# override for operator workarounds
-lit_gpt.model.KVCache = OverridenKVCache
+litgpt.model.KVCache = OverridenKVCache
# add the testing configurations
-lit_gpt.config.name_to_config.update(name_to_config)
-name_to_config.update(lit_gpt.config.name_to_config)
+litgpt.config.name_to_config.update(name_to_config)
+name_to_config.update(litgpt.config.name_to_config)
# manually expose for backwards compatibility
-Config = lit_gpt.Config
-GPT = lit_gpt.GPT
-RMSNorm = lit_gpt.rmsnorm.RMSNorm
-CausalSelfAttention = lit_gpt.model.CausalSelfAttention
-LLaMAMLP = lit_gpt.model.LLaMAMLP
-build_rope_cache = lit_gpt.model.build_rope_cache
-apply_rope = lit_gpt.model.apply_rope
-Block = lit_gpt.model.Block
+Config = litgpt.Config
+GPT = litgpt.GPT
+RMSNorm = litgpt.model.RMSNorm
+CausalSelfAttention = litgpt.model.CausalSelfAttention
+LLaMAMLP = litgpt.model.LLaMAMLP
+build_rope_cache = litgpt.model.build_rope_cache
+apply_rope = litgpt.model.apply_rope
+Block = litgpt.model.Block
diff --git a/thunder/tests/test_cudnn_executor.py b/thunder/tests/test_cudnn_executor.py
index 4128e02914..ab985aecbf 100644
--- a/thunder/tests/test_cudnn_executor.py
+++ b/thunder/tests/test_cudnn_executor.py
@@ -1,4 +1,3 @@
-import random
from functools import partial
from typing import Any
@@ -35,24 +34,19 @@ def grad_scaled_dot_product_attention_reference_generator(op, device, dtype, req
n_head = 2
N = 8 # batch size
-
- # TODO: multiple of 8 seems to produce NaNs
- L = random.randint(1, 10) * 64 # query's sequence length
-
- alignment_factor = 8
- S = random.randint(1, 10) * alignment_factor # key/value's sequence length
- E = random.randint(8, 16) * alignment_factor # query/key's embedding size
- Ev = random.randint(8, 16) * alignment_factor # value's embedding size
+ L = 640 # query's sequence length
+ S = 80 # key/value's sequence length
+ E = 128 # query/key's embedding size
+ Ev = 64 # value's embedding size
# 4-dim (multiheaded) causal cases
q, k, v = make(N, n_head, L, E), make(N, n_head, S, E), make(N, n_head, S, Ev)
yield SampleInput(q, k, v, None, dropout_p=0.0, is_causal=True)
- # TODO: cudnnex seems to have a few mismatches. Will be enabled in a later PR.
# Non-contiguous input tensor case
- nq = make(N, n_head, L, E).permute(0, 1, 3, 2)
- nk = make(N, n_head, L, E).permute(0, 1, 3, 2)
- nv = make(N, n_head, L, E).permute(0, 1, 3, 2)
+ nq = make(N, n_head, E, L).permute(0, 1, 3, 2)
+ nk = make(N, n_head, E, S).permute(0, 1, 3, 2)
+ nv = make(N, n_head, Ev, S).permute(0, 1, 3, 2)
yield SampleInput(nq, nk, nv, None, dropout_p=0.0, is_causal=False)
# Test the scale factor which was added in torch 2.1
diff --git a/thunder/tests/test_extend.py b/thunder/tests/test_extend.py
index b7dd300d12..592943d9e7 100644
--- a/thunder/tests/test_extend.py
+++ b/thunder/tests/test_extend.py
@@ -10,6 +10,7 @@
from thunder.core.proxies import TensorProxy
from thunder.core.transforms import grad, get_grad, put_grads
from thunder.extend import OperatorExecutor, register_executor, deregister_executor, get_all_executors
+from lightning_utilities.core.imports import package_available
def test_extend_core():
@@ -127,7 +128,9 @@ def test_get_all_executors_includes_all_native_executors():
"python",
"transformer_engine",
}
- actual.discard("triton") # remove when triton can always be imported
+ if package_available("triton"):
+ # `triton` may be installed on a system without a GPU.
+ expected.update({"triton"})
if torch.cuda.is_available():
expected.update({"nvfuser"})
assert actual == expected
diff --git a/thunder/tests/test_interpreter.py b/thunder/tests/test_interpreter.py
index 447647b8bb..3f0dfb0a94 100644
--- a/thunder/tests/test_interpreter.py
+++ b/thunder/tests/test_interpreter.py
@@ -1337,7 +1337,7 @@ def foo(a):
def foo():
# test relative import
- from .lit_gpt_model import Config
+ from .litgpt_model import Config
return Config
@@ -1345,18 +1345,18 @@ def foo():
def foo():
# test relative import
- from . import lit_gpt_model
+ from . import litgpt_model
- return lit_gpt_model.Config
+ return litgpt_model.Config
assert jit(foo)() is foo()
# reload is implemented using exec of the module
- from . import lit_gpt_model
+ from . import litgpt_model
import importlib
- importlib.reload(lit_gpt_model)
- assert hasattr(lit_gpt_model, "GPT")
+ importlib.reload(litgpt_model)
+ assert hasattr(litgpt_model, "GPT")
def test_locals_lookaside(jit):
@@ -3071,7 +3071,7 @@ def test_nanogpt(jit):
def test_litgpt(jit):
from thunder.benchmarks import LitGPTBenchmark
- from thunder.tests.lit_gpt_model import Config
+ from thunder.tests.litgpt_model import Config
cfg: Config = Config.from_name("gpt-neox-like")
bench = LitGPTBenchmark(config=cfg, device="cpu", dtype=torch.bfloat16, requires_grad=True)
diff --git a/thunder/tests/test_jit_general.py b/thunder/tests/test_jit_general.py
index 28fc8a54fc..3671bd12e7 100644
--- a/thunder/tests/test_jit_general.py
+++ b/thunder/tests/test_jit_general.py
@@ -16,7 +16,7 @@
import thunder
from thunder.core.interpreter import is_jitting, InterpreterError
-from thunder.tests import lit_gpt_model
+from thunder.tests import litgpt_model
import thunder.clang as clang
from thunder.core.options import INTERPRETATION_OPTIONS, CACHE_OPTIONS
import thunder.torch as ltorch
@@ -505,7 +505,7 @@ def h(d, c):
def test_litgpt():
from thunder.benchmarks import LitGPTBenchmark
- from thunder.tests.lit_gpt_model import Config
+ from thunder.tests.litgpt_model import Config
cfg: Config = Config.from_name("gpt-neox-like")
bench = LitGPTBenchmark(config=cfg, device="cpu", dtype=torch.bfloat16, requires_grad=True)
@@ -627,16 +627,16 @@ def test_litgpt_variants(name, device):
device = torch.device(device)
x = torch.randint(0, 200, (5, 5), device=device)
- config = lit_gpt_model.Config.from_name(name)
+ config = litgpt_model.Config.from_name(name)
with device:
- reference = lit_gpt_model.GPT(config)
+ reference = litgpt_model.GPT(config)
expected_logits = reference(x)
expected_logits.sum().backward()
with device:
- model = lit_gpt_model.GPT(config)
+ model = litgpt_model.GPT(config)
model.load_state_dict(reference.state_dict())
tom = thunder.jit(model, executors=nvfuserex if device.type == "cuda" else torchex)
actual_logits = tom(x)
@@ -677,10 +677,10 @@ def test_litgpt_variants_kvcache(name, device):
device = torch.device(device)
x = torch.randint(0, 200, (1, 2), device=device)
- config = lit_gpt_model.Config.from_name(name)
+ config = litgpt_model.Config.from_name(name)
with device:
- model = lit_gpt_model.GPT(config)
+ model = litgpt_model.GPT(config)
model.max_seq_length = 3
for p in model.parameters():