diff --git a/.azure/gpu-tests.yml b/.azure/gpu-tests.yml index 22ba01eddc..1bf99b0aaa 100644 --- a/.azure/gpu-tests.yml +++ b/.azure/gpu-tests.yml @@ -65,8 +65,14 @@ jobs: # drop pt from requirements so not to interfere with the existing one bash .azure/remove-torch-lines.sh requirements/base.txt cat requirements/base.txt + # double check on test requirements pip install -r requirements/test.txt + + # https://docs.codecov.com/docs/codecov-uploader + curl -Os https://uploader.codecov.io/latest/linux/codecov + chmod +x codecov + # install this package python setup.py develop displayName: 'Install package & ...' @@ -85,6 +91,12 @@ jobs: --durations=250 \ --numprocesses=9 \ --ignore=thunder/tests/distributed --ignore=thunder/tests/test_networks.py + # compile coverage results + python -m coverage report + python -m coverage xml + # upload to codecov + ./codecov --token=$(CODECOV_TOKEN) --commit=$(Build.SourceVersion) \ + --flags=gpu,pytest,regular --name="GPU-coverage" --env=linux,azure condition: ne(variables['testing'], 'distributed') displayName: 'Testing: regular' @@ -95,6 +107,12 @@ jobs: thunder/tests/test_networks.py \ -m "not standalone" \ -v --random-order-seed=42 --durations=0 --numprocesses=3 + # compile coverage results + python -m coverage report + python -m coverage xml + # upload to codecov + ./codecov --token=$(CODECOV_TOKEN) --commit=$(Build.SourceVersion) \ + --flags=gpu,pytest,networks --name="GPU-coverage" --env=linux,azure condition: ne(variables['testing'], 'distributed') displayName: 'Testing: networks' @@ -108,6 +126,12 @@ jobs: - bash: | # run all found tests in given past as standalone bash scripts/run_standalone_tests.sh "thunder/tests/distributed" + # compile coverage results + python -m coverage report + python -m coverage xml + # upload to codecov + ./codecov --token=$(CODECOV_TOKEN) --commit=$(Build.SourceVersion) \ + --flags=gpu,pytest,distributed --name="GPU-coverage" --env=linux,azure condition: eq(variables['testing'], 'distributed') displayName: 'Testing: distributed' diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index a9d21a849b..1d016662eb 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -1,10 +1,13 @@ -# Before submitting +
+ Before submitting

 - [ ] Was this discussed/approved via a Github issue? (no need for typos and docs improvements)
 - [ ] Did you read the [contributor guideline](https://github.com/Lightning-AI/pytorch-lightning/blob/main/.github/CONTRIBUTING.md), Pull Request section?
 - [ ] Did you make sure to update the docs?
 - [ ] Did you write any new necessary tests?
+
+ ## What does this PR do? Fixes # (issue). diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml index cbd8fb2aa6..310279fbbd 100644 --- a/.github/workflows/ci-testing.yml +++ b/.github/workflows/ci-testing.yml @@ -114,15 +114,15 @@ jobs: coverage report coverage xml - #- name: Upload coverage to Codecov - # uses: codecov/codecov-action@v3 - # with: - # token: ${{ secrets.CODECOV_TOKEN }} - # file: ./coverage.xml - # flags: unittests - # env_vars: OS,PYTHON - # name: codecov-umbrella - # fail_ci_if_error: false + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v3 + with: + token: ${{ secrets.CODECOV_TOKEN }} + file: ./coverage.xml + flags: unittests + env_vars: OS,PYTHON + name: codecov-umbrella + fail_ci_if_error: false testing-guardian: diff --git a/.github/workflows/docs-build.yml b/.github/workflows/docs-build.yml index d17ae79eef..b8d87e466d 100644 --- a/.github/workflows/docs-build.yml +++ b/.github/workflows/docs-build.yml @@ -14,7 +14,7 @@ defaults: shell: bash jobs: - build-docs: + docs-make: uses: Lightning-AI/utilities/.github/workflows/check-docs.yml@v0.11.0 with: python-version: "3.10" @@ -28,7 +28,7 @@ jobs: env: GCP_TARGET: "gs://lightning-docs-thunder" steps: - - uses: actions/download-artifact@v4 + - uses: actions/download-artifact@v3 with: name: docs-html-${{ github.sha }} path: docs/build/ @@ -50,7 +50,7 @@ jobs: # Uploading docs to GCS, so they can be served on lightning.ai - name: Upload docs/thunder/latest to GCS 🪣 - if: github.ref == 'refs/heads/master' + if: github.ref == 'refs/heads/main' run: gsutil -m rsync -d -R docs/build/html/ ${GCP_TARGET}/latest # Uploading docs to GCS, so they can be served on lightning.ai diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7b7b8d1f6c..f16e1ef98f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -17,6 +17,9 @@ repos: - id: check-toml - id: check-json - id: check-added-large-files + with: + maxkb: 250 + enforce-all: true - id: check-docstring-first - id: detect-private-key diff --git a/README.md b/README.md index f89dce73b0..ce5399d570 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,6 @@
-<img alt="Thunder" src="…">
+<img alt="Thunder" src="…">
+<img alt="Thunder" src="…">
diff --git a/dockers/ubuntu-cuda/Dockerfile b/dockers/ubuntu-cuda/Dockerfile index e815d827f6..2213921897 100644 --- a/dockers/ubuntu-cuda/Dockerfile +++ b/dockers/ubuntu-cuda/Dockerfile @@ -24,6 +24,7 @@ ARG CUDNN_FRONTEND_CHECKOUT="v1.1.0" ARG PYTHON_VERSION="3.10" ARG TORCH_VERSION="2.2.1" ARG TRITON_VERSION="2.2.0" +ARG TORCH_INSTALL="stable" SHELL ["/bin/bash", "-c"] # https://techoverflow.net/2019/05/18/how-to-fix-configuring-tzdata-interactive-input-when-building-docker-images/ @@ -96,7 +97,7 @@ ENV \ TORCH_CUDA_ARCH_LIST="8.0" \ CUDA_SELECT_NVCC_ARCH_FLAGS="8.0" -ARG TORCH_INSTALL="wheel" +ARG TORCH_INSTALL RUN \ if [ "${TORCH_INSTALL}" == "source" ]; then \ @@ -122,15 +123,26 @@ RUN \ --index-url="https://download.pytorch.org/whl/cu${CUDA_VERSION_MM//'.'/''}"; \ fi +ARG TORCH_INSTALL + RUN \ - # building nvFuser from source - git clone https://github.com/NVIDIA/Fuser.git && \ - cd Fuser && \ - git submodule update --init --recursive && \ - pip install -r requirements.txt && \ - python setup.py install --no-test --no-benchmark && \ - cd .. && \ - rm -rf Fuser + if [ "${TORCH_INSTALL}" == "source" ]; then \ + # building nvFuser from source + git clone https://github.com/NVIDIA/Fuser.git && \ + cd Fuser && \ + git submodule update --init --recursive && \ + pip install -r requirements.txt && \ + python setup.py install --no-test --no-benchmark && \ + cd .. && \ + rm -rf Fuser ; \ + elif [ "${TORCH_INSTALL}" == "test" ]; then \ + echo "Not supported option" ; \ + else \ + # installing pytorch from wheels \ + CUDA_VERSION_MM=${CUDA_VERSION%.*} && \ + TORCH_VERSION_MM=${TORCH_VERSION%.*} && \ + pip install -U "nvfuser-cu${CUDA_VERSION_MM/./}-torch${TORCH_VERSION_MM/./}" ; \ + fi RUN \ ls -lh requirements/ && \ diff --git a/docs/source/_static/images/LightningThunderDarkModewByline.png b/docs/source/_static/images/LightningThunderDarkModewByline.png new file mode 100644 index 0000000000..310739641c Binary files /dev/null and b/docs/source/_static/images/LightningThunderDarkModewByline.png differ diff --git a/docs/source/_static/images/LightningThunderLightModewByline.png b/docs/source/_static/images/LightningThunderLightModewByline.png new file mode 100644 index 0000000000..7effda5ec9 Binary files /dev/null and b/docs/source/_static/images/LightningThunderLightModewByline.png differ diff --git a/docs/source/_static/images/normalized_training_throughput_zero2.png b/docs/source/_static/images/normalized_training_throughput_zero2.png index be6e5888c3..60dc0725a9 100644 Binary files a/docs/source/_static/images/normalized_training_throughput_zero2.png and b/docs/source/_static/images/normalized_training_throughput_zero2.png differ diff --git a/docs/source/_static/images/training_throughput_single.png b/docs/source/_static/images/training_throughput_single.png index 6c0a7029a4..ad66af29db 100644 Binary files a/docs/source/_static/images/training_throughput_single.png and b/docs/source/_static/images/training_throughput_single.png differ diff --git a/docs/source/conf.py b/docs/source/conf.py index 052fa2437c..598dc42053 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -92,7 +92,7 @@ def _transform_changelog(path_in: str, path_out: str) -> None: "sphinx.ext.linkcode", "sphinx.ext.autosummary", "sphinx.ext.napoleon", - "sphinx.ext.imgmath", + "sphinx.ext.mathjax", "myst_parser", "nbsphinx", "sphinx_autodoc_typehints", @@ -209,6 +209,11 @@ def _transform_changelog(path_in: str, path_out: str) -> None: (master_doc, project + ".tex", project + " Documentation", author, "manual"), ] +# 
MathJax configuration +mathjax3_config = { + "tex": {"packages": {"[+]": ["ams", "newcommand", "configMacros"]}}, +} + # -- Options for manual page output ------------------------------------------ # One entry per manual page. List of tuples diff --git a/notebooks/dev_tutorials/fsdp_tutorial.ipynb b/notebooks/dev_tutorials/fsdp_tutorial.ipynb index a4f61b47c3..b8cef2a2ff 100644 --- a/notebooks/dev_tutorials/fsdp_tutorial.ipynb +++ b/notebooks/dev_tutorials/fsdp_tutorial.ipynb @@ -1764,7 +1764,7 @@ "%%writefile thunder_fsdp_simple_example.py\n", "\n", "# imports\n", - "from thunder.tests.lit_gpt_model import GPT, Config\n", + "from thunder.tests.litgpt_model import GPT, Config\n", "import torch\n", "import torch.distributed\n", "import thunder\n", diff --git a/notebooks/zero_to_thunder.ipynb b/notebooks/zero_to_thunder.ipynb index a1a888cc72..ffd6b5fe1c 100644 --- a/notebooks/zero_to_thunder.ipynb +++ b/notebooks/zero_to_thunder.ipynb @@ -312,8 +312,8 @@ } ], "source": [ - "from lit_gpt import GPT\n", - "from thunder.tests.lit_gpt_model import Config\n", + "from litgpt import GPT\n", + "from thunder.tests.litgpt_model import Config\n", "cfg = Config.from_name('Llama-2-7b-hf')\n", "cfg.n_layer = 16 # fewer layers\n", "torch.set_default_dtype(torch.bfloat16)\n", @@ -3326,7 +3326,7 @@ ], "source": [ "%%writefile zero_to_thunder_fsdp_simple_example.py\n", - "from thunder.tests.lit_gpt_model import GPT, Config\n", + "from thunder.tests.litgpt_model import GPT, Config\n", "import os\n", "import torch, torch.distributed\n", "import thunder, thunder.distributed\n", @@ -3470,7 +3470,7 @@ }, "outputs": [], "source": [ - "import lit_gpt\n", + "import litgpt\n", "def apply_rope_copy(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:\n", " head_size = x.size(-1)\n", " x1 = x[..., : head_size // 2] # (B, nh, T, hs/2)\n", @@ -3493,7 +3493,7 @@ "\n", "Say we have a function `apply_rope` applying the RoPE transformation in PyTorch.\n", "\n", - "In thunder, we define a *meta* function that only defines the metadata (like shapes) of outputs and the actual implementation for each operator and then register the pair with our executor using the `register_operator` function and tell it to use the new symbol instead of the original function `lit_gpt.model.apply_rope`.\n" + "In thunder, we define a *meta* function that only defines the metadata (like shapes) of outputs and the actual implementation for each operator and then register the pair with our executor using the `register_operator` function and tell it to use the new symbol instead of the original function `litgpt.model.apply_rope`.\n" ] }, { @@ -3504,17 +3504,17 @@ "outputs": [], "source": [ "import torch, thunder\n", - "from thunder.tests.lit_gpt_model import GPT\n", + "from thunder.tests.litgpt_model import GPT\n", "from thunder import TensorProxy\n", "\n", "def apply_rope_impl(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:\n", - " return lit_gpt.model.apply_rope(x, cos, sin)\n", + " return litgpt.model.apply_rope(x, cos, sin)\n", "\n", "def apply_rope_meta(x: TensorProxy, cos: TensorProxy, sin: TensorProxy) -> TensorProxy:\n", " return TensorProxy(like=x)\n", "\n", "apply_rope = my_ex.register_operator('apply_rope', like=apply_rope_meta, fn=apply_rope_impl,\n", - " replaces=lit_gpt.model.apply_rope)" + " replaces=litgpt.model.apply_rope)" ] }, { @@ -3569,7 +3569,7 @@ "with torch.device('cuda'): m = GPT.from_name('llama2-like'); Q = torch.randn(2, 128, 4096, 16)\n", "\n", "def test_apply_rope(x, m):\n", - " 
return lit_gpt.model.apply_rope(x, m.cos, m.sin)\n", + " return litgpt.model.apply_rope(x, m.cos, m.sin)\n", "\n", "thunder_apply_rope = thunder.jit(test_apply_rope, executors=(my_ex,) + thunder.get_default_executors()) \n", "\n", diff --git a/requirements/notebooks.txt b/requirements/notebooks.txt index 47a14902ca..7c37419cd6 100644 --- a/requirements/notebooks.txt +++ b/requirements/notebooks.txt @@ -1 +1,3 @@ ipython[all] ==8.22.2 + +litgpt @ git+https://github.com/Lightning-AI/lit-gpt@24d5eba1724c953b7506edc041a7da1ce226c129 diff --git a/requirements/test.txt b/requirements/test.txt index c882272833..621ebaa7c5 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -11,7 +11,7 @@ expecttest ==0.2.1 # for test_ddp.py hypothesis ==6.99.10 # for test_ddp.py numpy # for test_ops.py einops # for test_einops.py -lit_gpt @ git+https://github.com/Lightning-AI/lit-gpt@f241d94df59d82b2017bfdcd3800ac8779eb45f5 +litgpt @ git+https://github.com/Lightning-AI/lit-gpt@24d5eba1724c953b7506edc041a7da1ce226c129 absl-py # thunder/benchmarks/test_benchmark_litgpt.py pandas # thunder/benchmarks/test_benchmark_litgpt.py xlsxwriter # thunder/benchmarks/test_benchmark_litgpt.py diff --git a/thunder/benchmarks/__init__.py b/thunder/benchmarks/__init__.py index e9ba261c45..f4b0f2ac03 100644 --- a/thunder/benchmarks/__init__.py +++ b/thunder/benchmarks/__init__.py @@ -24,9 +24,9 @@ import thunder.core.devices as Devices from thunder.core.transforms import grad, clear_grads, populate_grads import thunder.executors as executors -from thunder.tests import nanogpt_model, hf_bart_self_attn, lit_gpt_model +from thunder.tests import nanogpt_model, hf_bart_self_attn, litgpt_model from thunder.tests.make_tensor import make_tensor, make_tensor_like -from thunder.tests.lit_gpt_model import Config as LitGPTConfig +from thunder.tests.litgpt_model import Config as LitGPTConfig # List of all benchmarks benchmarks: list = [] @@ -1875,7 +1875,7 @@ class LlamaMLPBenchmark(Benchmark, metaclass=UserFacingBenchmarkMeta): _args = ( BenchmarkArg( name="config", - description="The Lit-GPT config to use. Default is 'Llama-2-7b-hf'. See the lit_gpt_model.py for details.", + description="The Lit-GPT config to use. Default is 'Llama-2-7b-hf'. See the litgpt_model.py for details.", ), BenchmarkArg( name="batchdims", @@ -1935,7 +1935,7 @@ def make_batch(self) -> tuple[list, dict]: def fn(self) -> Callable: module = ( - lit_gpt_model.LLaMAMLP(self.config) + litgpt_model.LLaMAMLP(self.config) .to(device=self.device, dtype=self.tdtype) .requires_grad_(self.requires_grad) ) @@ -1946,7 +1946,7 @@ class LitGPTCausalSelfAttentionBenchmark(Benchmark, metaclass=UserFacingBenchmar _args = ( BenchmarkArg( name="config", - description="The Lit-GPT config to use. Default is 'Llama-2-7b-hf'. See the lit_gpt_model.py for details.", + description="The Lit-GPT config to use. Default is 'Llama-2-7b-hf'. 
See the litgpt_model.py for details.", ), BenchmarkArg( name="batchdims", @@ -2005,7 +2005,7 @@ def make_batch(self) -> tuple[list, dict]: def fn(self) -> Callable: module = ( - lit_gpt_model.CausalSelfAttention(self.config) + litgpt_model.CausalSelfAttention(self.config) .to(device=self.device, dtype=self.tdtype) .requires_grad_(self.requires_grad) ) @@ -2086,7 +2086,7 @@ def make_batch(self) -> tuple[list, dict]: def fn(self) -> Callable: module = ( - lit_gpt_model.RMSNorm(self.size, self.dim, self.eps) + litgpt_model.RMSNorm(self.size, self.dim, self.eps) .to(device=self.device, dtype=self.tdtype) .requires_grad_(self.requires_grad) ) @@ -2168,7 +2168,7 @@ def make_batch(self) -> tuple[list, dict]: def fn(self) -> Callable: gpt = ( - lit_gpt_model.GPT(self.config) + litgpt_model.GPT(self.config) .to(device=self.device, dtype=self.model_tdtype) .requires_grad_(self.requires_grad) ) @@ -2199,7 +2199,7 @@ def __init__(self, config, use_apex) -> None: super().__init__() self.config = config - self.apply_rope = lit_gpt_model.apply_rope + self.apply_rope = litgpt_model.apply_rope self.use_apex = use_apex def forward( @@ -2254,7 +2254,7 @@ class LlamaQKVSplitRopeBenchmark(Benchmark, metaclass=UserFacingBenchmarkMeta): _args = ( BenchmarkArg( name="config", - description="The Lit-GPT config to use. Default is 'Llama-2-7b-hf'. See the lit_gpt_model.py for details.", + description="The Lit-GPT config to use. Default is 'Llama-2-7b-hf'. See the litgpt_model.py for details.", ), BenchmarkArg( name="batchdims", @@ -2610,7 +2610,7 @@ def __init__( dtype: dtypes.dtype = thunder.bfloat16, requires_grad: bool = True, ) -> None: - from thunder.tests.lit_gpt_model import Config + from thunder.tests.litgpt_model import Config litgptconfig = Config.from_name(config) if not isinstance(config, Config) else config nanogptconfig = NanoGPTConfig( @@ -2793,7 +2793,7 @@ def __init__( # Sets required benchmark parameters self.devices: list[str] = [device] - self.cos, self.sin = lit_gpt_model.build_rope_cache( + self.cos, self.sin = litgpt_model.build_rope_cache( seq_len=seq_length, n_elem=self.config.rope_n_elem, device=self.device ) @@ -2806,9 +2806,7 @@ def make_batch(self) -> tuple[list, dict]: def fn(self) -> Callable: model = ( - lit_gpt_model.Block(self.config) - .to(device=self.device, dtype=self.tdtype) - .requires_grad_(self.requires_grad) + litgpt_model.Block(self.config).to(device=self.device, dtype=self.tdtype).requires_grad_(self.requires_grad) ) return model diff --git a/thunder/benchmarks/benchmark_litgpt.py b/thunder/benchmarks/benchmark_litgpt.py index 9120584989..877f99f38e 100644 --- a/thunder/benchmarks/benchmark_litgpt.py +++ b/thunder/benchmarks/benchmark_litgpt.py @@ -7,7 +7,7 @@ import torch.distributed as torch_dist import thunder -from thunder.tests.lit_gpt_model import Config, GPT, Block +from thunder.tests.litgpt_model import Config, GPT, Block from lightning.fabric.utilities.throughput import measure_flops from lightning.fabric.utilities import Throughput diff --git a/thunder/benchmarks/distributed.py b/thunder/benchmarks/distributed.py index c46a699ca5..3437c47e30 100644 --- a/thunder/benchmarks/distributed.py +++ b/thunder/benchmarks/distributed.py @@ -14,7 +14,7 @@ LitGPTBenchmark, LitGPTConfig, ) -from thunder.tests.lit_gpt_model import name_to_config +from thunder.tests.litgpt_model import name_to_config from thunder.distributed import FSDPBucketingStrategy from thunder.distributed import FSDPType @@ -299,7 +299,7 @@ def parse_args() -> argparse.Namespace: from 
torch.distributed.fsdp.wrap import transformer_auto_wrap_policy from thunder.benchmarks import get_default_torch_fsdp_executor from thunder.tests.nanogpt_model import Block as NanoGPTBlock - from thunder.tests.lit_gpt_model import Block as GPTBlock + from thunder.tests.litgpt_model import Block as GPTBlock sharding_strategy = ShardingStrategy.SHARD_GRAD_OP auto_wrap_policies = ( diff --git a/thunder/benchmarks/targets.py b/thunder/benchmarks/targets.py index 9f02b23d35..9b38e26c3d 100644 --- a/thunder/benchmarks/targets.py +++ b/thunder/benchmarks/targets.py @@ -39,7 +39,7 @@ thunder_sdpa_torch_compile_nvfuser_executor, ) -from thunder.tests.lit_gpt_model import Config as LitGPTConfig +from thunder.tests.litgpt_model import Config as LitGPTConfig APEX_FUSED_ROPE_AVAILABLE: bool = package_available("fused_rotary_positional_embedding") diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py index e626a77f2a..29046d5778 100644 --- a/thunder/core/jit_ext.py +++ b/thunder/core/jit_ext.py @@ -7,7 +7,8 @@ from collections.abc import Callable, Sequence import weakref import random -from functools import partial, wraps +from functools import partial, wraps, reduce +import operator import copy import contextvars from contextlib import contextmanager @@ -1347,6 +1348,20 @@ def bind_inputs(name, trace, input_vars, input_proxies): trace.args = input_proxies +def _get_process_group_from(*fn_and_args) -> Optional["ProcessGroup"]: + # `ddp` and `fsdp` transforms add attribute `procses_group_for_ddp` + # on the Module that they wrap. This module could be passed to `thunder.jit` + # as the function to be jitted or as an argument of the function to be jitted. + found_pg = None + for fn_or_arg in fn_and_args: + pg = getattr(fn_or_arg, "process_group_for_ddp", None) + if pg is not None and found_pg is None: + found_pg = pg + elif pg is not None and pg != found_pg: + raise NotImplementedError("jitting modules with different ProcessGroup is not supported currently.") + return found_pg + + def thunder_general_jit( fn: Callable, args, kwargs, /, *, sharp_edges: SHARP_EDGES_OPTIONS ) -> tuple[TraceCtx, TraceCtx]: @@ -1369,7 +1384,7 @@ def thunder_general_jit( si.varkwargs = ("kwargs", None) prologue_trace._siginfo = si - process_group_for_ddp = getattr(fn, "process_group_for_ddp", None) + process_group_for_ddp: Optional["ProcessGroup"] = _get_process_group_from(fn, *args, *kwargs.values()) ctx: GeneralJitCtx = GeneralJitCtx( prologue_trace, computation_trace, sharp_edges=sharp_edges, process_group_for_ddp=process_group_for_ddp ) diff --git a/thunder/executors/triton_crossentropy.py b/thunder/executors/triton_crossentropy.py index fc22770434..277e605e78 100644 --- a/thunder/executors/triton_crossentropy.py +++ b/thunder/executors/triton_crossentropy.py @@ -1,635 +1,10 @@ -import math -from enum import Enum - -import torch - - -import thunder.torch as ltorch from thunder.executors import triton_utils +from thunder.extend import OperatorExecutor - -# Requires triton 2.1 or greater -min_triton_version = "2.1" triton_version: None | str = triton_utils.triton_version() -assert triton_version is not None, f"Trying to import a Triton executor, but Triton is unavailable" -TRITON_AVAILABLE: bool = triton_utils.is_triton_version_at_least(min_triton_version) -assert ( - TRITON_AVAILABLE -), f"Trying to import a Triton executor, but it requires Triton version {min_triton_version} or greater, and the current Triton version is {triton_version}" - -from thunder.extend import OperatorExecutor, register_executor - -triton_ex: 
OperatorExecutor = OperatorExecutor("triton", version=triton_version) -register_executor(triton_ex) - -import triton # noqa: E402 -import triton.language as tl # noqa: E402 - -# Temporarily borrowed from https://github.com/openai/triton -FORWARD_NUM_STAGES = 1 - - -class TritonDtype(Enum): - kFP16 = 0 - kBF16 = 1 - kFP32 = 2 - kFP64 = 3 - - -_TORCH2DTYPE = { - torch.float16: TritonDtype.kFP16, - torch.bfloat16: TritonDtype.kBF16, - torch.float32: TritonDtype.kFP32, - torch.float64: TritonDtype.kFP64, -} -_DTYPE2TRITON = { - TritonDtype.kFP16: tl.float16, - TritonDtype.kBF16: tl.bfloat16, - TritonDtype.kFP32: tl.float32, - TritonDtype.kFP64: tl.float64, -} - - -@triton.jit -def _class_indices_forward( - LOGITS, - PROBS, - IDX, - LOSS, - weight, - N, - WEIGHT_BUFFER, - smoothing_factor, - log_size_logits, - WEIGHTS: tl.constexpr, - CLASS_INDICES: tl.constexpr, - LABEL_SMOOTHING: tl.constexpr, - IGNORE_INDEX: tl.constexpr, - BUFFER_DTYPE: tl.constexpr, - BLOCK: tl.constexpr, -): - buffer_dtype = _DTYPE2TRITON[BUFFER_DTYPE.value] - row = tl.program_id(0) - cols = tl.arange(0, BLOCK) - logit_start_ptrs = LOGITS + row * N - logit_ptrs = logit_start_ptrs + cols - m_prev = -float("inf") - l_prev = 0.0 - m_prev = m_prev.to(buffer_dtype) - l_prev = l_prev.to(buffer_dtype) - - for start_n in range(0, tl.cdiv(N, BLOCK)): - row_logits = tl.load( - logit_ptrs, - mask=cols < N - (start_n * BLOCK), - other=-float("inf"), - ).to(buffer_dtype) - - m_curr = tl.maximum(tl.max(row_logits, 0), m_prev) - l_prev *= tl.exp(m_prev - m_curr) - p = tl.exp(row_logits - m_curr) - l_curr = tl.sum(p, 0) + l_prev - l_prev = l_curr - m_prev = m_curr - logit_ptrs += BLOCK - logit_ptrs = logit_start_ptrs + cols - output_ptrs = PROBS + row * N + cols - WRIT_PROBS = PROBS + row * N + cols - if LABEL_SMOOTHING: - sum_total = 0.0 - sum_total = sum_total.to(buffer_dtype) - weights_total = 0.0 - weights_total = weights_total.to(buffer_dtype) - if WEIGHTS: - weight_ptr = weight + cols - - l_prev_log = tl.log(l_prev) - # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA) - for start_n in range(0, tl.cdiv(N, BLOCK)): - row_logits = tl.load( - logit_ptrs, - mask=cols < N - start_n * BLOCK, - other=l_prev_log + m_prev, - ).to(buffer_dtype) - if LABEL_SMOOTHING and WEIGHTS: - full_weights_val = tl.load(weight_ptr, mask=cols < N - start_n * BLOCK, other=0.0) - weights_total += tl.sum(full_weights_val, 0) - - row_minus_max = row_logits - m_prev - log_softmax = l_prev_log - row_minus_max - - if LABEL_SMOOTHING and WEIGHTS: - log_softmax *= full_weights_val - - if LABEL_SMOOTHING: - sum_total += tl.sum(log_softmax, 0) - # Store it back - - tl.store( - WRIT_PROBS, - log_softmax, - mask=cols < N - start_n * BLOCK, - ) - logit_ptrs += BLOCK - WRIT_PROBS += BLOCK - if LABEL_SMOOTHING and WEIGHTS: - weight_ptr += BLOCK - - idx = tl.load(IDX + row) - use_class = 0.0 - if IGNORE_INDEX >= 0: - use_class = idx == IGNORE_INDEX - READ_PROBS = PROBS + row * N + idx - tl.debug_barrier() - # write-back loss - probs = tl.load(READ_PROBS) - if WEIGHTS and not LABEL_SMOOTHING: - weight_ptr = weight + idx - weights_val = tl.load(weight_ptr) - probs = weights_val * probs - if LABEL_SMOOTHING: - tl.store(WEIGHT_BUFFER + row, weights_total) - probs = (1 - smoothing_factor) * probs + smoothing_factor * (sum_total) / N - probs = probs * (1.0 - use_class) - - tl.store(LOSS + row, probs) - - -@triton.jit -def _class_probs_forward( - LOGITS, - PROBS, - IDX, - LOSS, - weight, - N, - WEIGHT_BUFFER, - smoothing_factor, - 
log_size_logits, - WEIGHTS: tl.constexpr, - CLASS_INDICES: tl.constexpr, - LABEL_SMOOTHING: tl.constexpr, - IGNORE_INDEX: tl.constexpr, - BUFFER_DTYPE: tl.constexpr, - BLOCK: tl.constexpr, -): - buffer_dtype = _DTYPE2TRITON[BUFFER_DTYPE.value] - row = tl.program_id(0) - cols = tl.arange(0, BLOCK) - logit_start_ptrs = LOGITS + row * N - logit_ptrs = logit_start_ptrs + cols - m_prev = -float("inf") - l_prev = 0.0 - m_prev = m_prev.to(buffer_dtype) - l_prev = l_prev.to(buffer_dtype) - - for start_n in range(0, tl.cdiv(N, BLOCK)): - row_logits = tl.load( - logit_ptrs, - mask=cols < N - (start_n * BLOCK), - other=-float("inf"), - ).to(buffer_dtype) - - m_curr = tl.maximum(tl.max(row_logits, 0), m_prev) - l_prev *= tl.exp(m_prev - m_curr) - p = tl.exp(row_logits - m_curr) - l_curr = tl.sum(p, 0) + l_prev - l_prev = l_curr - m_prev = m_curr - logit_ptrs += BLOCK - logit_ptrs = logit_start_ptrs + cols - output_ptrs = PROBS + row * N + cols - WRIT_PROBS = PROBS + row * N + cols - - sum_total = 0.0 - weights_total = 0.0 - sum_total = sum_total.to(buffer_dtype) - weights_total = weights_total.to(buffer_dtype) - idx_ptr = IDX + row * N + cols - if WEIGHTS: - weight_ptr = weight + cols - - l_prev_log = tl.log(l_prev) - # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA) - for start_n in range(0, tl.cdiv(N, BLOCK)): - row_logits = tl.load( - logit_ptrs, - mask=cols < N - start_n * BLOCK, - other=l_prev_log + m_prev, - ).to(buffer_dtype) - idx = tl.load(idx_ptr, mask=cols < N - start_n * BLOCK, other=0.0) - full_weights_val = (1.0 - smoothing_factor) * idx + smoothing_factor / N - if WEIGHTS: - weights_val = tl.load(weight_ptr, mask=cols < N - start_n * BLOCK, other=0.0) - full_weights_val = weights_val * full_weights_val - else: - full_weights_val = tl.where(cols < N - start_n * BLOCK, full_weights_val, 0.0) - weights_total += tl.sum(full_weights_val, 0) - - row_minus_max = row_logits - m_prev - log_softmax = l_prev_log - row_minus_max - - log_softmax *= full_weights_val - sum_total += tl.sum(log_softmax, 0) - # Store it back - - tl.store( - WRIT_PROBS, - log_softmax, - mask=cols < N - start_n * BLOCK, - ) - logit_ptrs += BLOCK - WRIT_PROBS += BLOCK - idx_ptr += BLOCK - if WEIGHTS: - weight_ptr += BLOCK - - tl.store(WEIGHT_BUFFER + row, weights_total) - probs = sum_total - - tl.store(LOSS + row, probs) - - -@triton.autotune( - configs=[ - # fmt: off - triton.Config({'BLOCK': 1024}, num_stages=FORWARD_NUM_STAGES, num_warps=1), - triton.Config({'BLOCK': 2048}, num_stages=FORWARD_NUM_STAGES, num_warps=8), - triton.Config({'BLOCK': 4096}, num_stages=FORWARD_NUM_STAGES, num_warps=8), - triton.Config({'BLOCK': 8192}, num_stages=FORWARD_NUM_STAGES, num_warps=16), - triton.Config({'BLOCK': 16384}, num_stages=FORWARD_NUM_STAGES, num_warps=16), - # fmt: on - ], - key=[ - "N", - "CLASS_INDICES", - "log_size_logits", - "BUFFER_DTYPE", - ], -) -@triton.jit -def _forward( - LOGITS, - PROBS, - IDX, - LOSS, - weight, - N, - WEIGHT_BUFFER, - smoothing_factor, - log_size_logits, - WEIGHTS: tl.constexpr, - CLASS_INDICES: tl.constexpr, - LABEL_SMOOTHING: tl.constexpr, - IGNORE_INDEX: tl.constexpr, - BUFFER_DTYPE: tl.constexpr, - BLOCK: tl.constexpr, -): - if CLASS_INDICES: - _class_indices_forward( - LOGITS, - PROBS, - IDX, - LOSS, - weight, - N, - WEIGHT_BUFFER, - smoothing_factor, - log_size_logits, - WEIGHTS, - CLASS_INDICES, - LABEL_SMOOTHING, - IGNORE_INDEX, - BUFFER_DTYPE, - BLOCK, - ) - else: - _class_probs_forward( - LOGITS, - PROBS, - IDX, - LOSS, - weight, - N, - 
WEIGHT_BUFFER, - smoothing_factor, - log_size_logits, - WEIGHTS, - CLASS_INDICES, - LABEL_SMOOTHING, - IGNORE_INDEX, - BUFFER_DTYPE, - BLOCK, - ) - - -@triton.autotune( - configs=[ - # fmt: off - triton.Config({'BLOCK': 1024}, num_stages=1, num_warps=1), - triton.Config({'BLOCK': 2048}, num_stages=1, num_warps=8), - triton.Config({'BLOCK': 4096}, num_stages=1, num_warps=8), - triton.Config({'BLOCK': 8192}, num_stages=1, num_warps=16), - triton.Config({'BLOCK': 16384}, num_stages=1, num_warps=16), - # fmt: on - ], - key=[ - "N", - "CLASS_INDICES", - "log_size_logits", - "BUFFER_DTYPE", - ], -) -@triton.jit -def _backward( - PROBS, - IDX, - DPROBS, - dprob_stride, - DIN, - weight, - N, - WEIGHT_BUFFER, - smoothing_factor, - log_size_logits, - WEIGHTS: tl.constexpr, - CLASS_INDICES: tl.constexpr, - LABEL_SMOOTHING: tl.constexpr, - IGNORE_INDEX: tl.constexpr, - BUFFER_DTYPE: tl.constexpr, - BLOCK: tl.constexpr, -): - buffer_dtype = _DTYPE2TRITON[BUFFER_DTYPE.value] - row = tl.program_id(0) - start_n = tl.program_id(1) - cols = tl.arange(0, BLOCK) - PROBS = PROBS + row * N - # pointers to probs - probs_start = PROBS + cols + BLOCK * start_n - # for start_n in range(0, tl.cdiv(N, BLOCK)): # need to change this - probs = -tl.load( - probs_start, - mask=cols < N - (start_n * BLOCK), - other=float("inf"), - ).to(buffer_dtype) - DIN = DIN + row * N + cols + BLOCK * start_n - dout = tl.load(DPROBS + row * dprob_stride).to(buffer_dtype) - # We know d(-log(p[i])/dlogit[k] = -id_mat[i,k] + p[k] - # and we have -log(p[k]) stored in PROBS, so this is easy - if CLASS_INDICES: - idx = tl.load(IDX + row) - delta = ((start_n * BLOCK) + cols) == idx - # write result in-place in PROBS - if IGNORE_INDEX >= 0: - use_class = idx == IGNORE_INDEX - dout = dout * (1 - use_class) - if LABEL_SMOOTHING: - if WEIGHTS: - weight_ptr = weight + cols + BLOCK * start_n - full_weights_val = tl.load(weight_ptr, mask=cols < N - start_n * BLOCK, other=0.0).to(buffer_dtype) - weights_val = tl.load(weight + idx) - probs = probs / full_weights_val - probs = tl.exp(probs) - if WEIGHTS: - weights_total = tl.load(WEIGHT_BUFFER + row) - numerator_contrib = weights_val * (1.0 - smoothing_factor) * (probs - delta) - mean_contrib = ((weights_total * probs) - (full_weights_val)) * smoothing_factor / N - else: - numerator_contrib = (1.0 - smoothing_factor) * (probs - delta) - mean_contrib = (smoothing_factor * probs) - (smoothing_factor / N) - - din = (numerator_contrib + mean_contrib) * dout - - else: - probs = tl.exp(probs) - din = (probs - delta) * dout - if WEIGHTS: - weight_ptr = weight + idx - weights_val = tl.load(weight_ptr) - din = weights_val * din - else: - idx = tl.load( - IDX + row * N + cols + BLOCK * start_n, - mask=cols < N - start_n * BLOCK, - other=0.0, - ).to(buffer_dtype) - full_weights_val = (1.0 - smoothing_factor) * idx + smoothing_factor / N - weights_total = tl.load(WEIGHT_BUFFER + row) - if WEIGHTS: - weight_ptr = weight + cols + BLOCK * start_n - weights_val = tl.load(weight_ptr, mask=cols < N - start_n * BLOCK, other=0.0).to(buffer_dtype) - full_weights_val = weights_val * full_weights_val - probs = probs / full_weights_val - probs = tl.exp(probs.to(buffer_dtype)) - weighted_probs = probs * weights_total - weighted_probs_per_class = weighted_probs - full_weights_val - din = (weighted_probs_per_class) * dout - - tl.store(DIN, din.to(DIN.dtype.element_ty), mask=cols + BLOCK * start_n < N) - - -class CrossEntropy(torch.autograd.Function): - @staticmethod - def forward( - ctx, - logits, - indices, - weight, - 
ignore_index, - reduction, - label_smoothing, - ): - buffer_dtype = None - # make sure we can use triton - # assert ( - # indices.dtype == torch.int64 - # ), "Indices are expected to be of type long." - assert weight is None or (len(weight.shape) == 1 and weight.shape[0] == logits.shape[-1]) - # make kernel - if buffer_dtype is None: - if logits.dtype in [torch.bfloat16, torch.float16]: - buffer_dtype = torch.float32 - else: - buffer_dtype = logits.dtype - buffer_dtype_enum = _TORCH2DTYPE[buffer_dtype] - device, dtype = logits.device, logits.dtype - n_cols = logits.shape[-1] - # run the kernel - result = torch.empty((logits.shape[0],), dtype=dtype, device=device) - # result = torch.empty_like(indices, dtype=dtype, device=device) - neg_logprobs = torch.empty_like(logits, dtype=buffer_dtype, device=device) - weights_buffer = torch.empty_like(result, dtype=buffer_dtype) - grid = lambda opt: (logits.numel() // n_cols,) - log_size_logits = int(math.log(math.prod(logits.shape) / n_cols)) - _forward[grid]( - logits, - neg_logprobs, - indices, - result, - weight, - n_cols, - weights_buffer, - label_smoothing, - log_size_logits, - WEIGHTS=(weight is not None), - CLASS_INDICES=(indices.dtype == torch.int64), - LABEL_SMOOTHING=(label_smoothing > 0.0), - IGNORE_INDEX=ignore_index, - BUFFER_DTYPE=buffer_dtype_enum, - ) - # save for backward - ctx.save_for_backward(neg_logprobs, indices, weights_buffer) - ctx.WEIGHT = weight - ctx.label_smoothing = label_smoothing - ctx.ignore_index = ignore_index - ctx.reduction = reduction - ctx.buffer_dtype = buffer_dtype_enum - if reduction == "none": - return result - elif reduction == "sum": - return result.sum(dim=0) - elif reduction == "mean": - if indices.dtype == torch.int64: - denom = (indices != ignore_index).float() - if weight is not None: - class_weights = weight[indices] - denom *= class_weights - denom = denom.sum() - else: - denom = indices.shape[0] - ctx.denom = denom - return (result.sum(dim=0) / denom).to(dtype) - - @staticmethod - def backward(ctx, dneg_logprobs): - """We know d(-log(p[i])/dlogit[k] = -id_mat[i,k] + p[k] - so we initialize the gradient as neg_logprobs, so we can just exponentiate - to get p[k], which is most of what we need... neg_logprobs will be - modified in place to become the gradient we want - """ - # load saved tensors - reduction = ctx.reduction - if reduction == "mean" or reduction == "sum": - dneg_logprobs = dneg_logprobs.expand(1) - neg_logprobs, indices, weights_buffer = ctx.saved_tensors - din = torch.empty_like(neg_logprobs) - weight = ctx.WEIGHT - buffer_dtype = ctx.buffer_dtype - # run the kernel - # neg_logprobs will be modified in place to become our gradient: - n_cols = neg_logprobs.shape[-1] - grid = lambda opt: ( - neg_logprobs.numel() // n_cols, - triton.cdiv(n_cols, opt["BLOCK"]), - ) - log_size_logits = int(math.log(math.prod(neg_logprobs.shape) / n_cols)) - _backward[grid]( - neg_logprobs, - indices, - dneg_logprobs, - dneg_logprobs.stride(0), - din, - weight, - n_cols, - weights_buffer, - ctx.label_smoothing, - log_size_logits, - WEIGHTS=(weight is not None), - CLASS_INDICES=(indices.dtype == torch.int64), - LABEL_SMOOTHING=(ctx.label_smoothing > 0.0), - IGNORE_INDEX=ctx.ignore_index, - BUFFER_DTYPE=buffer_dtype, - ) - if ctx.reduction == "mean": - din /= ctx.denom - return din, None, None, None, None, None, None - - -def cross_entropy( - input, - target, - weight=None, - ignore_index=-100, - reduction="mean", - label_smoothing=0.0, -): - r""" - Returns the Cross Entropy loss of input. 
If the target is class indcies - then the ignore_index argument is applicable, while the label_smoothing argument - is not. On the other hand, if the target is class probabilites, then the - label_smoothing argument is applicable, while the ignore_index argument is not. - - Args: - input: Tensor of shape (B, N) - where B is the batch dim and N is the number of classes - target: Int Tensor of shape (B,), min = 0, max = N-1 or - Float Tensor of shape (B, N), rows sum to 1.0 - Int tensor of class labels. - weight: Optional, Float Tensor of shape (N,) - Weight to scale each class - ignore_index: Int, which class label should be ignored - reduction: String: ['none', 'sum', 'mean'] - label_smoothing: Float between 0 and 1 - """ - return CrossEntropy.apply( - input, - target, - weight, - ignore_index, - reduction, - label_smoothing, - ) - - -# TODO: What is correct handling of ignore_index? -def cross_entropy_impl( - a, - target, - weight=None, - size_average=None, - ignore_index=-100, - reduce=None, - reduction="mean", - label_smoothing=0.0, -): - loss = cross_entropy(a, target, weight, ignore_index, reduction, label_smoothing) - - return loss - - -def cross_entropy_checker( - a, - /, - target, - weight=None, - size_average=None, - ignore_index=-100, - reduce=None, - reduction="mean", - label_smoothing=0.0, -) -> bool: - if triton is None: - return False - - torch_dtype = ltorch.to_torch_dtype(a.dtype) - if torch_dtype not in (torch.float16, torch.bfloat16, torch.float32, torch.float64): - return False - - # These arguments are deprecated and not supported - if size_average is not None or reduce is not None: - return False - - # We only support reduction of "sum", "mean" or "none" - if reduction not in ["sum", "mean", "none"]: - return False - - if len(a.shape) != 2: - return False - - return True - -import thunder.torch as ltorch +triton_ex: None | OperatorExecutor = None +if triton_version is not None: + from thunder.executors.triton_crossentropy_impl import triton_ex as impl_ex -ce = triton_ex.register_operator("triton_crossentropy", like=ltorch.cross_entropy, fn=cross_entropy_impl) -triton_ex.register_implementation(ltorch.cross_entropy, ce, checker=cross_entropy_checker) + triton_ex = impl_ex diff --git a/thunder/executors/triton_crossentropy_impl.py b/thunder/executors/triton_crossentropy_impl.py new file mode 100644 index 0000000000..ff36fcb450 --- /dev/null +++ b/thunder/executors/triton_crossentropy_impl.py @@ -0,0 +1,631 @@ +import math +from enum import Enum + +import torch + +from thunder.extend import OperatorExecutor, register_executor +from thunder.executors import triton_utils + +# Requires triton 2.1 or greater +min_triton_version = "2.1" + +triton_version: None | str = triton_utils.triton_version() +TRITON_AVAILABLE: bool = triton_utils.is_triton_version_at_least(min_triton_version) +assert ( + TRITON_AVAILABLE +), f"Trying to import a Triton executor, but it requires Triton version {min_triton_version} or greater, and the current Triton version is {triton_version}" + +triton_ex: OperatorExecutor = OperatorExecutor("triton", version=triton_version) +register_executor(triton_ex) + +import triton # noqa: E402 +import triton.language as tl # noqa: E402 + +# Temporarily borrowed from https://github.com/openai/triton +FORWARD_NUM_STAGES = 1 + + +class TritonDtype(Enum): + kFP16 = 0 + kBF16 = 1 + kFP32 = 2 + kFP64 = 3 + + +_TORCH2DTYPE = { + torch.float16: TritonDtype.kFP16, + torch.bfloat16: TritonDtype.kBF16, + torch.float32: TritonDtype.kFP32, + torch.float64: TritonDtype.kFP64, 
+} +_DTYPE2TRITON = { + TritonDtype.kFP16: tl.float16, + TritonDtype.kBF16: tl.bfloat16, + TritonDtype.kFP32: tl.float32, + TritonDtype.kFP64: tl.float64, +} + + +@triton.jit +def _class_indices_forward( + LOGITS, + PROBS, + IDX, + LOSS, + weight, + N, + WEIGHT_BUFFER, + smoothing_factor, + log_size_logits, + WEIGHTS: tl.constexpr, + CLASS_INDICES: tl.constexpr, + LABEL_SMOOTHING: tl.constexpr, + IGNORE_INDEX: tl.constexpr, + BUFFER_DTYPE: tl.constexpr, + BLOCK: tl.constexpr, +): + buffer_dtype = _DTYPE2TRITON[BUFFER_DTYPE.value] + row = tl.program_id(0) + cols = tl.arange(0, BLOCK) + logit_start_ptrs = LOGITS + row * N + logit_ptrs = logit_start_ptrs + cols + m_prev = -float("inf") + l_prev = 0.0 + m_prev = m_prev.to(buffer_dtype) + l_prev = l_prev.to(buffer_dtype) + + for start_n in range(0, tl.cdiv(N, BLOCK)): + row_logits = tl.load( + logit_ptrs, + mask=cols < N - (start_n * BLOCK), + other=-float("inf"), + ).to(buffer_dtype) + + m_curr = tl.maximum(tl.max(row_logits, 0), m_prev) + l_prev *= tl.exp(m_prev - m_curr) + p = tl.exp(row_logits - m_curr) + l_curr = tl.sum(p, 0) + l_prev + l_prev = l_curr + m_prev = m_curr + logit_ptrs += BLOCK + logit_ptrs = logit_start_ptrs + cols + output_ptrs = PROBS + row * N + cols + WRIT_PROBS = PROBS + row * N + cols + if LABEL_SMOOTHING: + sum_total = 0.0 + sum_total = sum_total.to(buffer_dtype) + weights_total = 0.0 + weights_total = weights_total.to(buffer_dtype) + if WEIGHTS: + weight_ptr = weight + cols + + l_prev_log = tl.log(l_prev) + # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA) + for start_n in range(0, tl.cdiv(N, BLOCK)): + row_logits = tl.load( + logit_ptrs, + mask=cols < N - start_n * BLOCK, + other=l_prev_log + m_prev, + ).to(buffer_dtype) + if LABEL_SMOOTHING and WEIGHTS: + full_weights_val = tl.load(weight_ptr, mask=cols < N - start_n * BLOCK, other=0.0) + weights_total += tl.sum(full_weights_val, 0) + + row_minus_max = row_logits - m_prev + log_softmax = l_prev_log - row_minus_max + + if LABEL_SMOOTHING and WEIGHTS: + log_softmax *= full_weights_val + + if LABEL_SMOOTHING: + sum_total += tl.sum(log_softmax, 0) + # Store it back + + tl.store( + WRIT_PROBS, + log_softmax, + mask=cols < N - start_n * BLOCK, + ) + logit_ptrs += BLOCK + WRIT_PROBS += BLOCK + if LABEL_SMOOTHING and WEIGHTS: + weight_ptr += BLOCK + + idx = tl.load(IDX + row) + use_class = 0.0 + if IGNORE_INDEX >= 0: + use_class = idx == IGNORE_INDEX + READ_PROBS = PROBS + row * N + idx + tl.debug_barrier() + # write-back loss + probs = tl.load(READ_PROBS) + if WEIGHTS and not LABEL_SMOOTHING: + weight_ptr = weight + idx + weights_val = tl.load(weight_ptr) + probs = weights_val * probs + if LABEL_SMOOTHING: + tl.store(WEIGHT_BUFFER + row, weights_total) + probs = (1 - smoothing_factor) * probs + smoothing_factor * (sum_total) / N + probs = probs * (1.0 - use_class) + + tl.store(LOSS + row, probs) + + +@triton.jit +def _class_probs_forward( + LOGITS, + PROBS, + IDX, + LOSS, + weight, + N, + WEIGHT_BUFFER, + smoothing_factor, + log_size_logits, + WEIGHTS: tl.constexpr, + CLASS_INDICES: tl.constexpr, + LABEL_SMOOTHING: tl.constexpr, + IGNORE_INDEX: tl.constexpr, + BUFFER_DTYPE: tl.constexpr, + BLOCK: tl.constexpr, +): + buffer_dtype = _DTYPE2TRITON[BUFFER_DTYPE.value] + row = tl.program_id(0) + cols = tl.arange(0, BLOCK) + logit_start_ptrs = LOGITS + row * N + logit_ptrs = logit_start_ptrs + cols + m_prev = -float("inf") + l_prev = 0.0 + m_prev = m_prev.to(buffer_dtype) + l_prev = l_prev.to(buffer_dtype) + + for start_n in range(0, 
tl.cdiv(N, BLOCK)): + row_logits = tl.load( + logit_ptrs, + mask=cols < N - (start_n * BLOCK), + other=-float("inf"), + ).to(buffer_dtype) + + m_curr = tl.maximum(tl.max(row_logits, 0), m_prev) + l_prev *= tl.exp(m_prev - m_curr) + p = tl.exp(row_logits - m_curr) + l_curr = tl.sum(p, 0) + l_prev + l_prev = l_curr + m_prev = m_curr + logit_ptrs += BLOCK + logit_ptrs = logit_start_ptrs + cols + output_ptrs = PROBS + row * N + cols + WRIT_PROBS = PROBS + row * N + cols + + sum_total = 0.0 + weights_total = 0.0 + sum_total = sum_total.to(buffer_dtype) + weights_total = weights_total.to(buffer_dtype) + idx_ptr = IDX + row * N + cols + if WEIGHTS: + weight_ptr = weight + cols + + l_prev_log = tl.log(l_prev) + # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA) + for start_n in range(0, tl.cdiv(N, BLOCK)): + row_logits = tl.load( + logit_ptrs, + mask=cols < N - start_n * BLOCK, + other=l_prev_log + m_prev, + ).to(buffer_dtype) + idx = tl.load(idx_ptr, mask=cols < N - start_n * BLOCK, other=0.0) + full_weights_val = (1.0 - smoothing_factor) * idx + smoothing_factor / N + if WEIGHTS: + weights_val = tl.load(weight_ptr, mask=cols < N - start_n * BLOCK, other=0.0) + full_weights_val = weights_val * full_weights_val + else: + full_weights_val = tl.where(cols < N - start_n * BLOCK, full_weights_val, 0.0) + weights_total += tl.sum(full_weights_val, 0) + + row_minus_max = row_logits - m_prev + log_softmax = l_prev_log - row_minus_max + + log_softmax *= full_weights_val + sum_total += tl.sum(log_softmax, 0) + # Store it back + + tl.store( + WRIT_PROBS, + log_softmax, + mask=cols < N - start_n * BLOCK, + ) + logit_ptrs += BLOCK + WRIT_PROBS += BLOCK + idx_ptr += BLOCK + if WEIGHTS: + weight_ptr += BLOCK + + tl.store(WEIGHT_BUFFER + row, weights_total) + probs = sum_total + + tl.store(LOSS + row, probs) + + +@triton.autotune( + configs=[ + # fmt: off + triton.Config({'BLOCK': 1024}, num_stages=FORWARD_NUM_STAGES, num_warps=1), + triton.Config({'BLOCK': 2048}, num_stages=FORWARD_NUM_STAGES, num_warps=8), + triton.Config({'BLOCK': 4096}, num_stages=FORWARD_NUM_STAGES, num_warps=8), + triton.Config({'BLOCK': 8192}, num_stages=FORWARD_NUM_STAGES, num_warps=16), + triton.Config({'BLOCK': 16384}, num_stages=FORWARD_NUM_STAGES, num_warps=16), + # fmt: on + ], + key=[ + "N", + "CLASS_INDICES", + "log_size_logits", + "BUFFER_DTYPE", + ], +) +@triton.jit +def _forward( + LOGITS, + PROBS, + IDX, + LOSS, + weight, + N, + WEIGHT_BUFFER, + smoothing_factor, + log_size_logits, + WEIGHTS: tl.constexpr, + CLASS_INDICES: tl.constexpr, + LABEL_SMOOTHING: tl.constexpr, + IGNORE_INDEX: tl.constexpr, + BUFFER_DTYPE: tl.constexpr, + BLOCK: tl.constexpr, +): + if CLASS_INDICES: + _class_indices_forward( + LOGITS, + PROBS, + IDX, + LOSS, + weight, + N, + WEIGHT_BUFFER, + smoothing_factor, + log_size_logits, + WEIGHTS, + CLASS_INDICES, + LABEL_SMOOTHING, + IGNORE_INDEX, + BUFFER_DTYPE, + BLOCK, + ) + else: + _class_probs_forward( + LOGITS, + PROBS, + IDX, + LOSS, + weight, + N, + WEIGHT_BUFFER, + smoothing_factor, + log_size_logits, + WEIGHTS, + CLASS_INDICES, + LABEL_SMOOTHING, + IGNORE_INDEX, + BUFFER_DTYPE, + BLOCK, + ) + + +@triton.autotune( + configs=[ + # fmt: off + triton.Config({'BLOCK': 1024}, num_stages=1, num_warps=1), + triton.Config({'BLOCK': 2048}, num_stages=1, num_warps=8), + triton.Config({'BLOCK': 4096}, num_stages=1, num_warps=8), + triton.Config({'BLOCK': 8192}, num_stages=1, num_warps=16), + triton.Config({'BLOCK': 16384}, num_stages=1, num_warps=16), + # fmt: on + ], + 
key=[ + "N", + "CLASS_INDICES", + "log_size_logits", + "BUFFER_DTYPE", + ], +) +@triton.jit +def _backward( + PROBS, + IDX, + DPROBS, + dprob_stride, + DIN, + weight, + N, + WEIGHT_BUFFER, + smoothing_factor, + log_size_logits, + WEIGHTS: tl.constexpr, + CLASS_INDICES: tl.constexpr, + LABEL_SMOOTHING: tl.constexpr, + IGNORE_INDEX: tl.constexpr, + BUFFER_DTYPE: tl.constexpr, + BLOCK: tl.constexpr, +): + buffer_dtype = _DTYPE2TRITON[BUFFER_DTYPE.value] + row = tl.program_id(0) + start_n = tl.program_id(1) + cols = tl.arange(0, BLOCK) + PROBS = PROBS + row * N + # pointers to probs + probs_start = PROBS + cols + BLOCK * start_n + # for start_n in range(0, tl.cdiv(N, BLOCK)): # need to change this + probs = -tl.load( + probs_start, + mask=cols < N - (start_n * BLOCK), + other=float("inf"), + ).to(buffer_dtype) + DIN = DIN + row * N + cols + BLOCK * start_n + dout = tl.load(DPROBS + row * dprob_stride).to(buffer_dtype) + # We know d(-log(p[i])/dlogit[k] = -id_mat[i,k] + p[k] + # and we have -log(p[k]) stored in PROBS, so this is easy + if CLASS_INDICES: + idx = tl.load(IDX + row) + delta = ((start_n * BLOCK) + cols) == idx + # write result in-place in PROBS + if IGNORE_INDEX >= 0: + use_class = idx == IGNORE_INDEX + dout = dout * (1 - use_class) + if LABEL_SMOOTHING: + if WEIGHTS: + weight_ptr = weight + cols + BLOCK * start_n + full_weights_val = tl.load(weight_ptr, mask=cols < N - start_n * BLOCK, other=0.0).to(buffer_dtype) + weights_val = tl.load(weight + idx) + probs = probs / full_weights_val + probs = tl.exp(probs) + if WEIGHTS: + weights_total = tl.load(WEIGHT_BUFFER + row) + numerator_contrib = weights_val * (1.0 - smoothing_factor) * (probs - delta) + mean_contrib = ((weights_total * probs) - (full_weights_val)) * smoothing_factor / N + else: + numerator_contrib = (1.0 - smoothing_factor) * (probs - delta) + mean_contrib = (smoothing_factor * probs) - (smoothing_factor / N) + + din = (numerator_contrib + mean_contrib) * dout + + else: + probs = tl.exp(probs) + din = (probs - delta) * dout + if WEIGHTS: + weight_ptr = weight + idx + weights_val = tl.load(weight_ptr) + din = weights_val * din + else: + idx = tl.load( + IDX + row * N + cols + BLOCK * start_n, + mask=cols < N - start_n * BLOCK, + other=0.0, + ).to(buffer_dtype) + full_weights_val = (1.0 - smoothing_factor) * idx + smoothing_factor / N + weights_total = tl.load(WEIGHT_BUFFER + row) + if WEIGHTS: + weight_ptr = weight + cols + BLOCK * start_n + weights_val = tl.load(weight_ptr, mask=cols < N - start_n * BLOCK, other=0.0).to(buffer_dtype) + full_weights_val = weights_val * full_weights_val + probs = probs / full_weights_val + probs = tl.exp(probs.to(buffer_dtype)) + weighted_probs = probs * weights_total + weighted_probs_per_class = weighted_probs - full_weights_val + din = (weighted_probs_per_class) * dout + + tl.store(DIN, din.to(DIN.dtype.element_ty), mask=cols + BLOCK * start_n < N) + + +class CrossEntropy(torch.autograd.Function): + @staticmethod + def forward( + ctx, + logits, + indices, + weight, + ignore_index, + reduction, + label_smoothing, + ): + buffer_dtype = None + # make sure we can use triton + # assert ( + # indices.dtype == torch.int64 + # ), "Indices are expected to be of type long." 
+ assert weight is None or (len(weight.shape) == 1 and weight.shape[0] == logits.shape[-1]) + # make kernel + if buffer_dtype is None: + if logits.dtype in [torch.bfloat16, torch.float16]: + buffer_dtype = torch.float32 + else: + buffer_dtype = logits.dtype + buffer_dtype_enum = _TORCH2DTYPE[buffer_dtype] + device, dtype = logits.device, logits.dtype + n_cols = logits.shape[-1] + # run the kernel + result = torch.empty((logits.shape[0],), dtype=dtype, device=device) + # result = torch.empty_like(indices, dtype=dtype, device=device) + neg_logprobs = torch.empty_like(logits, dtype=buffer_dtype, device=device) + weights_buffer = torch.empty_like(result, dtype=buffer_dtype) + grid = lambda opt: (logits.numel() // n_cols,) + log_size_logits = int(math.log(math.prod(logits.shape) / n_cols)) + _forward[grid]( + logits, + neg_logprobs, + indices, + result, + weight, + n_cols, + weights_buffer, + label_smoothing, + log_size_logits, + WEIGHTS=(weight is not None), + CLASS_INDICES=(indices.dtype == torch.int64), + LABEL_SMOOTHING=(label_smoothing > 0.0), + IGNORE_INDEX=ignore_index, + BUFFER_DTYPE=buffer_dtype_enum, + ) + # save for backward + ctx.save_for_backward(neg_logprobs, indices, weights_buffer) + ctx.WEIGHT = weight + ctx.label_smoothing = label_smoothing + ctx.ignore_index = ignore_index + ctx.reduction = reduction + ctx.buffer_dtype = buffer_dtype_enum + if reduction == "none": + return result + elif reduction == "sum": + return result.sum(dim=0) + elif reduction == "mean": + if indices.dtype == torch.int64: + denom = (indices != ignore_index).float() + if weight is not None: + class_weights = weight[indices] + denom *= class_weights + denom = denom.sum() + else: + denom = indices.shape[0] + ctx.denom = denom + return (result.sum(dim=0) / denom).to(dtype) + + @staticmethod + def backward(ctx, dneg_logprobs): + """We know d(-log(p[i])/dlogit[k] = -id_mat[i,k] + p[k] + so we initialize the gradient as neg_logprobs, so we can just exponentiate + to get p[k], which is most of what we need... neg_logprobs will be + modified in place to become the gradient we want + """ + # load saved tensors + reduction = ctx.reduction + if reduction == "mean" or reduction == "sum": + dneg_logprobs = dneg_logprobs.expand(1) + neg_logprobs, indices, weights_buffer = ctx.saved_tensors + din = torch.empty_like(neg_logprobs) + weight = ctx.WEIGHT + buffer_dtype = ctx.buffer_dtype + # run the kernel + # neg_logprobs will be modified in place to become our gradient: + n_cols = neg_logprobs.shape[-1] + grid = lambda opt: ( + neg_logprobs.numel() // n_cols, + triton.cdiv(n_cols, opt["BLOCK"]), + ) + log_size_logits = int(math.log(math.prod(neg_logprobs.shape) / n_cols)) + _backward[grid]( + neg_logprobs, + indices, + dneg_logprobs, + dneg_logprobs.stride(0), + din, + weight, + n_cols, + weights_buffer, + ctx.label_smoothing, + log_size_logits, + WEIGHTS=(weight is not None), + CLASS_INDICES=(indices.dtype == torch.int64), + LABEL_SMOOTHING=(ctx.label_smoothing > 0.0), + IGNORE_INDEX=ctx.ignore_index, + BUFFER_DTYPE=buffer_dtype, + ) + if ctx.reduction == "mean": + din /= ctx.denom + return din, None, None, None, None, None, None + + +def cross_entropy( + input, + target, + weight=None, + ignore_index=-100, + reduction="mean", + label_smoothing=0.0, +): + r""" + Returns the Cross Entropy loss of input. If the target is class indcies + then the ignore_index argument is applicable, while the label_smoothing argument + is not. 
On the other hand, if the target is class probabilites, then the + label_smoothing argument is applicable, while the ignore_index argument is not. + + Args: + input: Tensor of shape (B, N) + where B is the batch dim and N is the number of classes + target: Int Tensor of shape (B,), min = 0, max = N-1 or + Float Tensor of shape (B, N), rows sum to 1.0 + Int tensor of class labels. + weight: Optional, Float Tensor of shape (N,) + Weight to scale each class + ignore_index: Int, which class label should be ignored + reduction: String: ['none', 'sum', 'mean'] + label_smoothing: Float between 0 and 1 + """ + return CrossEntropy.apply( + input, + target, + weight, + ignore_index, + reduction, + label_smoothing, + ) + + +# TODO: What is correct handling of ignore_index? +def cross_entropy_impl( + a, + target, + weight=None, + size_average=None, + ignore_index=-100, + reduce=None, + reduction="mean", + label_smoothing=0.0, +): + loss = cross_entropy(a, target, weight, ignore_index, reduction, label_smoothing) + + return loss + + +def cross_entropy_checker( + a, + /, + target, + weight=None, + size_average=None, + ignore_index=-100, + reduce=None, + reduction="mean", + label_smoothing=0.0, +) -> bool: + if triton is None: + return False + + torch_dtype = ltorch.to_torch_dtype(a.dtype) + if torch_dtype not in (torch.float16, torch.bfloat16, torch.float32, torch.float64): + return False + + # These arguments are deprecated and not supported + if size_average is not None or reduce is not None: + return False + + # We only support reduction of "sum", "mean" or "none" + if reduction not in ["sum", "mean", "none"]: + return False + + if len(a.shape) != 2: + return False + + return True + + +import thunder.torch as ltorch + +ce = triton_ex.register_operator("triton_crossentropy", like=ltorch.cross_entropy, fn=cross_entropy_impl) +triton_ex.register_implementation(ltorch.cross_entropy, ce, checker=cross_entropy_checker) diff --git a/thunder/extend/__init__.py b/thunder/extend/__init__.py index 2acdcf6427..a1fd0c8361 100644 --- a/thunder/extend/__init__.py +++ b/thunder/extend/__init__.py @@ -306,13 +306,9 @@ def get_all_executors() -> tuple[Executor]: torch_compile, torchex, transformer_engineex, + triton_crossentropy, ) - if torch.cuda.is_available(): - # raise an error when a dependency is not available at import time - # TODO: this should only happen at runtime - from thunder.executors import triton_crossentropy - return tuple(_executor_map.values()) diff --git a/thunder/tests/distributed/test_ddp.py b/thunder/tests/distributed/test_ddp.py index c321a01394..3562a10a5c 100644 --- a/thunder/tests/distributed/test_ddp.py +++ b/thunder/tests/distributed/test_ddp.py @@ -843,6 +843,20 @@ def check_inflight_allgather_number(trc, n: int, is_bucket: bool): aft_trc = limit_in_flight_allgathers(bwd_trc, i, is_bucketing) check_inflight_allgather_number(aft_trc, i, is_bucketing) + def test_ddp_model_as_argument(self): + # Sanity test to make sure passing model as argument to + # thunder.jit with `ddp` compiles. 
+ device = torch.device("cuda", self.rank) + model = torch.nn.Linear(5, 10, bias=False, device=device) + x = torch.randn(2, 5, device=device) + + def fwd_loss(m, x): + return m(x).sum() + + model = thunder.distributed.ddp(model) + fwd_loss = thunder.jit(fwd_loss) + fwd_loss(model, x) + common_utils.instantiate_parametrized_tests(CompileDDPTest) diff --git a/thunder/tests/lit_gpt_model.py b/thunder/tests/litgpt_model.py similarity index 82% rename from thunder/tests/lit_gpt_model.py rename to thunder/tests/litgpt_model.py index 57a85089bc..23b51545a0 100644 --- a/thunder/tests/lit_gpt_model.py +++ b/thunder/tests/litgpt_model.py @@ -1,4 +1,4 @@ -"""Taken from https://github.com/Lightning-AI/lit-gpt/blob/main/lit_gpt/model.py""" +"""Taken from https://github.com/Lightning-AI/litgpt/blob/main/litgpt/model.py""" import torch import torch.nn as nn @@ -18,9 +18,9 @@ rotary_percentage=1.0, parallel_residual=False, bias=False, - _norm_class="RMSNorm", + norm_class_name="RMSNorm", norm_eps=1e-6, - _mlp_class="LLaMAMLP", + mlp_class_name="LLaMAMLP", intermediate_size=1376, ), dict( @@ -34,8 +34,8 @@ rotary_percentage=1.0, parallel_residual=False, bias=False, - _norm_class="RMSNorm", - _mlp_class="LLaMAMLP", + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", intermediate_size=11008, rope_condense_ratio=4, ), @@ -49,8 +49,8 @@ rotary_percentage=1.0, parallel_residual=False, bias=False, - _norm_class="RMSNorm", - _mlp_class="LLaMAMLP", + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", intermediate_size=1376, ), dict( @@ -87,9 +87,9 @@ rotary_percentage=1.0, parallel_residual=False, bias=False, - _norm_class="RMSNorm", + norm_class_name="RMSNorm", norm_eps=1e-05, - _mlp_class="LLaMAMLP", + mlp_class_name="LLaMAMLP", intermediate_size=1376, rope_base=1000000, ), @@ -104,9 +104,9 @@ n_query_groups=8, parallel_residual=False, bias=False, - _norm_class="RMSNorm", + norm_class_name="RMSNorm", norm_eps=1e-05, - _mlp_class="LLaMAMoE", + mlp_class_name="LLaMAMoE", intermediate_size=224, rope_base=1000000, n_expert=8, @@ -150,21 +150,20 @@ def reset_parameters(self) -> None: torch.nn.init.zeros_(self.v) -import lit_gpt -import lit_gpt.rmsnorm +import litgpt # override for operator workarounds -lit_gpt.model.KVCache = OverridenKVCache +litgpt.model.KVCache = OverridenKVCache # add the testing configurations -lit_gpt.config.name_to_config.update(name_to_config) -name_to_config.update(lit_gpt.config.name_to_config) +litgpt.config.name_to_config.update(name_to_config) +name_to_config.update(litgpt.config.name_to_config) # manually expose for backwards compatibility -Config = lit_gpt.Config -GPT = lit_gpt.GPT -RMSNorm = lit_gpt.rmsnorm.RMSNorm -CausalSelfAttention = lit_gpt.model.CausalSelfAttention -LLaMAMLP = lit_gpt.model.LLaMAMLP -build_rope_cache = lit_gpt.model.build_rope_cache -apply_rope = lit_gpt.model.apply_rope -Block = lit_gpt.model.Block +Config = litgpt.Config +GPT = litgpt.GPT +RMSNorm = litgpt.model.RMSNorm +CausalSelfAttention = litgpt.model.CausalSelfAttention +LLaMAMLP = litgpt.model.LLaMAMLP +build_rope_cache = litgpt.model.build_rope_cache +apply_rope = litgpt.model.apply_rope +Block = litgpt.model.Block diff --git a/thunder/tests/test_cudnn_executor.py b/thunder/tests/test_cudnn_executor.py index 4128e02914..ab985aecbf 100644 --- a/thunder/tests/test_cudnn_executor.py +++ b/thunder/tests/test_cudnn_executor.py @@ -1,4 +1,3 @@ -import random from functools import partial from typing import Any @@ -35,24 +34,19 @@ def grad_scaled_dot_product_attention_reference_generator(op, device, 
 
     n_head = 2
     N = 8  # batch size
-
-    # TODO: multiple of 8 seems to produce NaNs
-    L = random.randint(1, 10) * 64  # query's sequence length
-
-    alignment_factor = 8
-    S = random.randint(1, 10) * alignment_factor  # key/value's sequence length
-    E = random.randint(8, 16) * alignment_factor  # query/key's embedding size
-    Ev = random.randint(8, 16) * alignment_factor  # value's embedding size
+    L = 640  # query's sequence length
+    S = 80  # key/value's sequence length
+    E = 128  # query/key's embedding size
+    Ev = 64  # value's embedding size
 
     # 4-dim (multiheaded) causal cases
     q, k, v = make(N, n_head, L, E), make(N, n_head, S, E), make(N, n_head, S, Ev)
     yield SampleInput(q, k, v, None, dropout_p=0.0, is_causal=True)
 
-    # TODO: cudnnex seems to have a few mismatches. Will be enabled in a later PR.
     # Non-contiguous input tensor case
-    nq = make(N, n_head, L, E).permute(0, 1, 3, 2)
-    nk = make(N, n_head, L, E).permute(0, 1, 3, 2)
-    nv = make(N, n_head, L, E).permute(0, 1, 3, 2)
+    nq = make(N, n_head, E, L).permute(0, 1, 3, 2)
+    nk = make(N, n_head, E, S).permute(0, 1, 3, 2)
+    nv = make(N, n_head, Ev, S).permute(0, 1, 3, 2)
     yield SampleInput(nq, nk, nv, None, dropout_p=0.0, is_causal=False)
 
     # Test the scale factor which was added in torch 2.1
diff --git a/thunder/tests/test_extend.py b/thunder/tests/test_extend.py
index b7dd300d12..592943d9e7 100644
--- a/thunder/tests/test_extend.py
+++ b/thunder/tests/test_extend.py
@@ -10,6 +10,7 @@
 from thunder.core.proxies import TensorProxy
 from thunder.core.transforms import grad, get_grad, put_grads
 from thunder.extend import OperatorExecutor, register_executor, deregister_executor, get_all_executors
+from lightning_utilities.core.imports import package_available
 
 
 def test_extend_core():
@@ -127,7 +128,9 @@ def test_get_all_executors_includes_all_native_executors():
         "python",
         "transformer_engine",
     }
-    actual.discard("triton")  # remove when triton can always be imported
+    if package_available("triton"):
+        # `triton` may be installed on a system without a GPU.
+        expected.update({"triton"})
     if torch.cuda.is_available():
         expected.update({"nvfuser"})
     assert actual == expected
diff --git a/thunder/tests/test_interpreter.py b/thunder/tests/test_interpreter.py
index 447647b8bb..3f0dfb0a94 100644
--- a/thunder/tests/test_interpreter.py
+++ b/thunder/tests/test_interpreter.py
@@ -1337,7 +1337,7 @@ def foo(a):
 
     def foo():
         # test relative import
-        from .lit_gpt_model import Config
+        from .litgpt_model import Config
 
         return Config
 
@@ -1345,18 +1345,18 @@ def foo():
 
     def foo():
         # test relative import
-        from . import lit_gpt_model
+        from . import litgpt_model
 
-        return lit_gpt_model.Config
+        return litgpt_model.Config
 
     assert jit(foo)() is foo()
 
     # reload is implemented using exec of the module
-    from . import lit_gpt_model
+    from . import litgpt_model
     import importlib
 
-    importlib.reload(lit_gpt_model)
-    assert hasattr(lit_gpt_model, "GPT")
+    importlib.reload(litgpt_model)
+    assert hasattr(litgpt_model, "GPT")
 
 
 def test_locals_lookaside(jit):
@@ -3071,7 +3071,7 @@ def test_nanogpt(jit):
 
 def test_litgpt(jit):
     from thunder.benchmarks import LitGPTBenchmark
-    from thunder.tests.lit_gpt_model import Config
+    from thunder.tests.litgpt_model import Config
 
     cfg: Config = Config.from_name("gpt-neox-like")
     bench = LitGPTBenchmark(config=cfg, device="cpu", dtype=torch.bfloat16, requires_grad=True)
diff --git a/thunder/tests/test_jit_general.py b/thunder/tests/test_jit_general.py
index 28fc8a54fc..3671bd12e7 100644
--- a/thunder/tests/test_jit_general.py
+++ b/thunder/tests/test_jit_general.py
@@ -16,7 +16,7 @@
 import thunder
 from thunder.core.interpreter import is_jitting, InterpreterError
 
-from thunder.tests import lit_gpt_model
+from thunder.tests import litgpt_model
 import thunder.clang as clang
 from thunder.core.options import INTERPRETATION_OPTIONS, CACHE_OPTIONS
 import thunder.torch as ltorch
@@ -505,7 +505,7 @@ def h(d, c):
 
 def test_litgpt():
     from thunder.benchmarks import LitGPTBenchmark
-    from thunder.tests.lit_gpt_model import Config
+    from thunder.tests.litgpt_model import Config
 
     cfg: Config = Config.from_name("gpt-neox-like")
     bench = LitGPTBenchmark(config=cfg, device="cpu", dtype=torch.bfloat16, requires_grad=True)
@@ -627,16 +627,16 @@ def test_litgpt_variants(name, device):
     device = torch.device(device)
     x = torch.randint(0, 200, (5, 5), device=device)
-    config = lit_gpt_model.Config.from_name(name)
+    config = litgpt_model.Config.from_name(name)
 
     with device:
-        reference = lit_gpt_model.GPT(config)
+        reference = litgpt_model.GPT(config)
 
     expected_logits = reference(x)
     expected_logits.sum().backward()
 
     with device:
-        model = lit_gpt_model.GPT(config)
+        model = litgpt_model.GPT(config)
     model.load_state_dict(reference.state_dict())
 
     tom = thunder.jit(model, executors=nvfuserex if device.type == "cuda" else torchex)
     actual_logits = tom(x)
@@ -677,10 +677,10 @@ def test_litgpt_variants_kvcache(name, device):
     device = torch.device(device)
     x = torch.randint(0, 200, (1, 2), device=device)
-    config = lit_gpt_model.Config.from_name(name)
+    config = litgpt_model.Config.from_name(name)
 
     with device:
-        model = lit_gpt_model.GPT(config)
+        model = litgpt_model.GPT(config)
         model.max_seq_length = 3
 
     for p in model.parameters():