diff --git a/.azure/gpu-tests.yml b/.azure/gpu-tests.yml index 22ba01eddc..1bf99b0aaa 100644 --- a/.azure/gpu-tests.yml +++ b/.azure/gpu-tests.yml @@ -65,8 +65,14 @@ jobs: # drop pt from requirements so not to interfere with the existing one bash .azure/remove-torch-lines.sh requirements/base.txt cat requirements/base.txt + # double check on test requirements pip install -r requirements/test.txt + + # https://docs.codecov.com/docs/codecov-uploader + curl -Os https://uploader.codecov.io/latest/linux/codecov + chmod +x codecov + # install this package python setup.py develop displayName: 'Install package & ...' @@ -85,6 +91,12 @@ jobs: --durations=250 \ --numprocesses=9 \ --ignore=thunder/tests/distributed --ignore=thunder/tests/test_networks.py + # compile coverage results + python -m coverage report + python -m coverage xml + # upload to codecov + ./codecov --token=$(CODECOV_TOKEN) --commit=$(Build.SourceVersion) \ + --flags=gpu,pytest,regular --name="GPU-coverage" --env=linux,azure condition: ne(variables['testing'], 'distributed') displayName: 'Testing: regular' @@ -95,6 +107,12 @@ jobs: thunder/tests/test_networks.py \ -m "not standalone" \ -v --random-order-seed=42 --durations=0 --numprocesses=3 + # compile coverage results + python -m coverage report + python -m coverage xml + # upload to codecov + ./codecov --token=$(CODECOV_TOKEN) --commit=$(Build.SourceVersion) \ + --flags=gpu,pytest,networks --name="GPU-coverage" --env=linux,azure condition: ne(variables['testing'], 'distributed') displayName: 'Testing: networks' @@ -108,6 +126,14 @@ jobs: - bash: | + # abort on the first failing command: the step's exit status is that of its last command, so without this a failed test run would be masked by a successful codecov upload + set -e # run all found tests in given past as standalone bash scripts/run_standalone_tests.sh "thunder/tests/distributed" + # compile coverage results + python -m coverage report + python -m coverage xml + # upload to codecov + ./codecov --token=$(CODECOV_TOKEN) --commit=$(Build.SourceVersion) \ + --flags=gpu,pytest,distributed --name="GPU-coverage" --env=linux,azure condition: eq(variables['testing'], 'distributed') 
displayName: 'Testing: distributed' diff --git a/.github/workflows/ci-checks.yml b/.github/workflows/ci-checks.yml index ae28b10c53..37fff68466 100644 --- a/.github/workflows/ci-checks.yml +++ b/.github/workflows/ci-checks.yml @@ -10,10 +10,10 @@ concurrency: cancel-in-progress: ${{ ! (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/release/')) }} jobs: - #check-code: - # uses: Lightning-AI/utilities/.github/workflows/check-code.yml@main - # with: - # actions-ref: main + precommit-run: + uses: Lightning-AI/utilities/.github/workflows/check-precommit.yml@v0.10.1 + with: + python-version: "3.10" check-schema: uses: Lightning-AI/utilities/.github/workflows/check-schema.yml@v0.11.0 diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml index cbd8fb2aa6..310279fbbd 100644 --- a/.github/workflows/ci-testing.yml +++ b/.github/workflows/ci-testing.yml @@ -114,15 +114,15 @@ jobs: coverage report coverage xml - #- name: Upload coverage to Codecov - # uses: codecov/codecov-action@v3 - # with: - # token: ${{ secrets.CODECOV_TOKEN }} - # file: ./coverage.xml - # flags: unittests - # env_vars: OS,PYTHON - # name: codecov-umbrella - # fail_ci_if_error: false + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v3 + with: + token: ${{ secrets.CODECOV_TOKEN }} + file: ./coverage.xml + flags: unittests + env_vars: OS,PYTHON + name: codecov-umbrella + fail_ci_if_error: false testing-guardian: diff --git a/.github/workflows/code-lint.yml b/.github/workflows/code-lint.yml deleted file mode 100644 index 48a0ea6d77..0000000000 --- a/.github/workflows/code-lint.yml +++ /dev/null @@ -1,13 +0,0 @@ -name: Code linting - -on: [push] - -jobs: - - precommit-run: - uses: Lightning-AI/utilities/.github/workflows/check-precommit.yml@main - with: - python-version: "3.10" - push-fixes: true - secrets: - github-token: ${{ secrets.PAT_GHOST }} diff --git a/.github/workflows/docs-build.yml b/.github/workflows/docs-build.yml index 
d37e381b40..b8d87e466d 100644 --- a/.github/workflows/docs-build.yml +++ b/.github/workflows/docs-build.yml @@ -14,7 +14,7 @@ defaults: shell: bash jobs: - build-docs: + docs-make: uses: Lightning-AI/utilities/.github/workflows/check-docs.yml@v0.11.0 with: python-version: "3.10" @@ -22,27 +22,46 @@ jobs: install-tex: true deploy-docs: - # https://github.com/marketplace/actions/deploy-to-github-pages - needs: build-docs - if: github.event_name != 'pull_request' + needs: docs-make + if: github.repository_owner == 'Lightning-AI' && github.event_name == 'push' runs-on: ubuntu-latest + env: + GCP_TARGET: "gs://lightning-docs-thunder" steps: - - name: Checkout 🛎️ - uses: actions/checkout@v4 - with: - # If you're using actions/checkout@v4 you must set persist-credentials to false in most cases for the deployment to work correctly. - persist-credentials: false - uses: actions/download-artifact@v3 with: name: docs-html-${{ github.sha }} path: docs/build/ - - name: Deploy 🚀 - uses: JamesIves/github-pages-deploy-action@v4.5.0 + - name: Authenticate to Google Cloud + uses: google-github-actions/auth@v2 + with: + credentials_json: ${{ secrets.GCS_SA_KEY }} + + - name: Setup gcloud + uses: google-github-actions/setup-gcloud@v2 with: - token: ${{ secrets.GITHUB_TOKEN }} - branch: gh-pages # The branch the action should deploy to. - folder: docs/build/html # The folder the action should deploy. 
- clean: true # Automatically remove deleted files from the deploy branch - target-folder: docs # If you'd like to push the contents of the deployment folder into a specific directory - single-commit: true # you'd prefer to have a single commit on the deployment branch instead of full history + project_id: ${{ secrets.GCS_PROJECT }} + + # Uploading docs to GCS, so they can be served on lightning.ai + #- name: Upload docs/thunder/stable to GCS 🪣 + # if: startsWith(github.ref, 'refs/heads/release/') + # run: gsutil -m rsync -d -R docs/build/html/ ${GCP_TARGET}/stable + + # Uploading docs to GCS, so they can be served on lightning.ai + - name: Upload docs/thunder/latest to GCS 🪣 + if: github.ref == 'refs/heads/main' + run: gsutil -m rsync -d -R docs/build/html/ ${GCP_TARGET}/latest + + # Uploading docs to GCS, so they can be served on lightning.ai + #- name: Upload docs/thunder/release to GCS 🪣 + # if: startsWith(github.ref, 'refs/tags/') + # run: gsutil -m rsync -d -R docs/build/html/ ${GCP_TARGET}/${{ github.ref_name }} + + # Uploading docs as an archive to GCS, so they can serve as a backup + #- name: Upload docs as archive to GCS 🪣 + # if: startsWith(github.ref, 'refs/tags/') + # working-directory: docs/build + # run: | + # zip ${{ github.ref_name }}.zip -r html/ + # gsutil cp ${{ github.ref_name }}.zip ${GCP_TARGET} diff --git a/.github/workflows/greetings.yml b/.github/workflows/greetings.yml deleted file mode 100644 index bdcabdcf69..0000000000 --- a/.github/workflows/greetings.yml +++ /dev/null @@ -1,14 +0,0 @@ -name: Greetings -# https://github.com/marketplace/actions/first-interaction - -on: [issues] # pull_request - -jobs: - greeting: - runs-on: ubuntu-20.04 - steps: - - uses: actions/first-interaction@v1 - with: - repo-token: ${{ secrets.GITHUB_TOKEN }} - issue-message: 'Hi! thanks for your contribution!, great first issue!' - pr-message: 'Hey thanks for the input! Please give us a bit of time to review it!' 
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7b7b8d1f6c..f16e1ef98f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -17,6 +17,8 @@ repos: - id: check-toml - id: check-json - id: check-added-large-files + # hook options are passed as CLI flags through `args`; pre-commit has no `with:` key (that is GitHub Actions syntax) and would ignore it + args: ["--maxkb=250", "--enforce-all"] - id: check-docstring-first - id: detect-private-key diff --git a/README.md b/README.md index 7d60063864..f89dce73b0 100644 --- a/README.md +++ b/README.md @@ -125,19 +125,7 @@ Thunder is in its early stages and should not be used for production runs yet. However, it can already deliver outstanding performance on LLM model supported by [LitGPT](https://github.com/Lightning-AI/lit-gpt), such as Mistral, Llama 2, Gemma, Falcon, and others. -Run training loop for Llama, single-GPU: - -```bash -python examples/lit-gpt/train.py -``` - -Run training loop for Llama, multi-GPU, using FSDP: - -```bash -python examples/lit-gpt/train_fsdp.py -``` - -See [README.md](examples/lit-gpt/README.md) for details on running LitGPT with Thunder. +Check out [the LitGPT integration](https://github.com/Lightning-AI/litgpt/tree/main/extensions/thunder) to learn about running LitGPT and Thunder together. 
## Features diff --git a/docs/source/_static/images/LightningThunderDarkModewByline.png b/docs/source/_static/images/LightningThunderDarkModewByline.png new file mode 100644 index 0000000000..310739641c Binary files /dev/null and b/docs/source/_static/images/LightningThunderDarkModewByline.png differ diff --git a/docs/source/_static/images/LightningThunderLightModewByline.png b/docs/source/_static/images/LightningThunderLightModewByline.png new file mode 100644 index 0000000000..7effda5ec9 Binary files /dev/null and b/docs/source/_static/images/LightningThunderLightModewByline.png differ diff --git a/docs/source/_static/images/normalized_training_throughput_zero2.png b/docs/source/_static/images/normalized_training_throughput_zero2.png index be6e5888c3..60dc0725a9 100644 Binary files a/docs/source/_static/images/normalized_training_throughput_zero2.png and b/docs/source/_static/images/normalized_training_throughput_zero2.png differ diff --git a/docs/source/_static/images/training_throughput_single.png b/docs/source/_static/images/training_throughput_single.png index 6c0a7029a4..ad66af29db 100644 Binary files a/docs/source/_static/images/training_throughput_single.png and b/docs/source/_static/images/training_throughput_single.png differ diff --git a/docs/source/conf.py b/docs/source/conf.py index 052fa2437c..598dc42053 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -92,7 +92,7 @@ def _transform_changelog(path_in: str, path_out: str) -> None: "sphinx.ext.linkcode", "sphinx.ext.autosummary", "sphinx.ext.napoleon", - "sphinx.ext.imgmath", + "sphinx.ext.mathjax", "myst_parser", "nbsphinx", "sphinx_autodoc_typehints", @@ -209,6 +209,11 @@ def _transform_changelog(path_in: str, path_out: str) -> None: (master_doc, project + ".tex", project + " Documentation", author, "manual"), ] +# MathJax configuration +mathjax3_config = { + "tex": {"packages": {"[+]": ["ams", "newcommand", "configMacros"]}}, +} + # -- Options for manual page output 
------------------------------------------ # One entry per manual page. List of tuples diff --git a/examples/lit-gpt/.gitignore b/examples/lit-gpt/.gitignore deleted file mode 100644 index c3d41546e1..0000000000 --- a/examples/lit-gpt/.gitignore +++ /dev/null @@ -1,4 +0,0 @@ -checkpoints - -download.py -convert_hf_checkpoint.py diff --git a/examples/lit-gpt/1_forward.py b/examples/lit-gpt/1_forward.py deleted file mode 100644 index c0392e3382..0000000000 --- a/examples/lit-gpt/1_forward.py +++ /dev/null @@ -1,57 +0,0 @@ -import time - -import lightning as L -import torch -import torch._dynamo.config -import torch._inductor.config - -from thunder.tests.lit_gpt_model import GPT - - -@torch.inference_mode() -def main(name: str = "open_llama_7b", num_samples: int = 10, compile: str = "eager") -> None: - torch.set_float32_matmul_precision("high") - torch.set_default_dtype(torch.bfloat16) - device = torch.device("cuda") - - with device: - model = GPT.from_name(name) - encoded = torch.randint(0, model.config.padded_vocab_size, (10, model.max_seq_length)) - - model.eval() - - if compile == "inductor": - torch._dynamo.config.automatic_dynamic_shapes = True - torch._inductor.config.triton.unique_kernel_names = True - torch._inductor.config.coordinate_descent_tuning = True - model = torch.compile(model, fullgraph=True) - elif compile == "thunder": - import thunder - from thunder.executors.sdpaex import sdpa_ex - from thunder.executors.torch_compile import torch_compile_executor - - model = thunder.jit( - model, - disable_torch_autograd=True, - executors=[sdpa_ex, torch_compile_executor, thunder.nvfuser_executor, thunder.pytorch_executor], - ) - elif compile != "eager": - raise ValueError(compile) - - values = [] - L.seed_everything(1234) - for i in range(num_samples): - t0 = time.perf_counter() - _ = model(encoded) - torch.cuda.synchronize() - t = time.perf_counter() - t0 - values.append(t) - print(f"Time for inference {i + 1}: {t:.05f} sec total") - print(f"Best: 
{min(values):05f}") - print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB") - - -if __name__ == "__main__": - from jsonargparse import CLI - - CLI(main) diff --git a/examples/lit-gpt/README.md b/examples/lit-gpt/README.md deleted file mode 100644 index bf0b46199b..0000000000 --- a/examples/lit-gpt/README.md +++ /dev/null @@ -1,93 +0,0 @@ -# Lit-GPT benchmarks - -## Setup - -```bash -wget -nc https://raw.githubusercontent.com/Lightning-AI/lit-gpt/1a5e7c/scripts/download.py -pip install jsonargparse huggingface_hub sentencepiece tokenizers -pip install git+https://github.com/Lightning-AI/lit-gpt@1a5e7c -``` - -## [1 forward](1_forward.py) - -```bash -python 1_forward.py --compile thunder -``` - -Runs a single forward call with a (B=10 x T=2048) tensor: - -| Method | Time ↓ | Memory ↓ | -| -------- | ------ | -------- | -| Inductor | 1.18 s | 17.38 GB | -| Thunder | 1.27 s | 16.32 GB | -| Eager | 1.48 s | 17.44 GB | - -## [Single-device training](train.py) - -```shell -# setup -python download.py --repo_id openlm-research/open_llama_3b --tokenizer_only true -# run -python train.py --compile thunder --dynamic false -``` - -Static shapes (45 iters) - -| Method | Time ↓ | Memory ↓ | -| -------- | ------ | -------- | -| Inductor | 20.1 s | 20.95 GB | -| Thunder | 21.9 s | 23.75 GB | -| Eager | 24.6 s | 24.28 GB | - -Dynamic shapes (45 iters) - -| Method | Time ↓ | Memory ↓ | -| -------- | ------- | -------- | -| Inductor | 17.0 s | 20.69 GB | -| Eager | 17.6 s | 23.91 GB | -| Thunder | ~5715 s | - | - -## [Multi-device training](train_fsdp.py) - -```shell -# setup -python download.py --repo_id openlm-research/open_llama_3b --tokenizer_only true -# run -python train_fsdp.py --devices 2 --compile thunder --stage 2 --bucketing_strategy BLOCK -``` - -Static shapes (45 iters) - -| Stage | Bucketing | Method | Time ↓ | Memory ↓ | -| ----- | --------- | -------- | ------- | -------- | -| 2 | No | Inductor | Error | Error | -| 2 | No | Thunder | 23.29 s | 26.99 
GB | -| 2 | No | Eager | 27.76 s | 27.61 GB | -| | | | | | -| 2 | Block | Inductor | 21.71 s | 24.31 GB | -| 2 | Block | Thunder | 24.30 s | 26.96 GB | -| 2 | Block | Eager | 26.05 s | 27.67 GB | -| | | | | | -| 3 | No | Inductor | Error | Error | -| 3 | No | Thunder | 24.39 s | 20.25 GB | -| 3 | No | Eager | 28.56 s | 20.75 GB | -| | | | | | -| 3 | Block | Inductor | 21.76 s | 17.86 GB | -| 3 | Block | Thunder | 24.11 s | 26.93 GB | -| 3 | Block | Eager | 26.33 s | 21.23 GB | - -## Setup - -```text -Python version: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] (64-bit runtime) -Is debug build: False -CUDA used to build PyTorch: 12.1 -CUDA runtime version: 12.3.107 -GPU 0: NVIDIA A100-SXM4-40GB -Nvidia driver version: 545.23.08 - -pytorch-triton==3.0.0+901819d2b6 -torch==2.3.0.dev20240225+cu121 -lightning-thunder==51993f9a6894f59f3779b30485e72b93d5e7b150 -nvfuser_cu121==0.1.6.dev20240226 -``` diff --git a/examples/lit-gpt/_ddp_thunder.py b/examples/lit-gpt/_ddp_thunder.py deleted file mode 100644 index 1bd07619df..0000000000 --- a/examples/lit-gpt/_ddp_thunder.py +++ /dev/null @@ -1,226 +0,0 @@ -"""Fabric Strategy to support Thunder DDP: To be upstreamed into Fabric eventually.""" - -from contextlib import nullcontext -from datetime import timedelta -from typing import TYPE_CHECKING, Any, ContextManager, Dict, List, Optional, Tuple, Union - -import torch -import torch.distributed -from lightning_utilities.core.imports import RequirementCache -from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only -from torch import Tensor -from torch.nn import Module -from typing_extensions import override - -from lightning.fabric.accelerators.accelerator import Accelerator -from lightning.fabric.plugins.collectives.torch_collective import default_pg_timeout -from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment -from lightning.fabric.plugins.io.checkpoint_io import CheckpointIO -from 
lightning.fabric.plugins.precision import Precision -from lightning.fabric.strategies.launchers.subprocess_script import _SubprocessScriptLauncher -from lightning.fabric.strategies.parallel import ParallelStrategy -from lightning.fabric.strategies.strategy import TBroadcast, _BackwardSyncControl -from lightning.fabric.utilities.distributed import ( - ReduceOp, - _distributed_is_initialized, - _get_default_process_group_backend_for_device, - _init_dist_connection, - _sync_ddp_if_available, -) -from lightning.fabric.utilities.rank_zero import rank_zero_only - -if TYPE_CHECKING: - from thunder import Executor - - -_THUNDER_AVAILABLE = RequirementCache("lightning-thunder", "thunder") - - -class DDPThunderStrategy(ParallelStrategy): - def __init__( - self, - accelerator: Optional[Accelerator] = None, - parallel_devices: Optional[List[torch.device]] = None, - cluster_environment: Optional[ClusterEnvironment] = None, - checkpoint_io: Optional[CheckpointIO] = None, - precision: Optional[Precision] = None, - executors: Optional[Tuple[Union["Executor", str], ...]] = None, - process_group_backend: Optional[str] = None, - timeout: Optional[timedelta] = default_pg_timeout, - **kwargs: Any, - ): - if not _THUNDER_AVAILABLE: - raise ModuleNotFoundError(str(_THUNDER_AVAILABLE)) - super().__init__(accelerator=accelerator, checkpoint_io=checkpoint_io, precision=precision) - self.parallel_devices = parallel_devices - self.cluster_environment: Optional[ClusterEnvironment] = cluster_environment - - self.executors = _validate_executors(executors) - self._num_nodes = 1 - self._process_group_backend: Optional[str] = process_group_backend - self._timeout: Optional[timedelta] = timeout - self._backward_sync_control = _DDPBackwardSyncControl() - self._ddp_kwargs = kwargs - - @property - @override - def root_device(self) -> torch.device: - assert self.parallel_devices is not None - return self.parallel_devices[self.local_rank] - - @property - def num_nodes(self) -> int: - return 
self._num_nodes - - @num_nodes.setter - def num_nodes(self, num_nodes: int) -> None: - # note that world ranks is related to num_nodes, when resetting it, need to reset world ranks - self._num_nodes = num_nodes - - @property - def num_processes(self) -> int: - return len(self.parallel_devices) if self.parallel_devices is not None else 0 - - @property - @override - def distributed_sampler_kwargs(self) -> Dict[str, Any]: - return {"num_replicas": self.num_nodes * self.num_processes, "rank": self.global_rank} - - @override - def _configure_launcher(self) -> None: - assert self.cluster_environment is not None - if not self.cluster_environment.creates_processes_externally: - self._launcher = _SubprocessScriptLauncher(self.cluster_environment, self.num_processes, self.num_nodes) - - @property - def process_group_backend(self) -> Optional[str]: - return self._process_group_backend - - @override - def _configure_launcher(self) -> None: - assert self.cluster_environment is not None - self._launcher = _SubprocessScriptLauncher(self.cluster_environment, self.num_processes, self.num_nodes) - - @override - def setup_environment(self) -> None: - super().setup_environment() - self._setup_distributed() - - @override - def setup_module(self, module: Module) -> Module: - import thunder - - module = thunder.distributed.ddp(module, **self._ddp_kwargs) - - return thunder.jit(module, executors=self.executors) - - @override - def module_to_device(self, module: Module) -> None: - module.to(self.root_device) - - @override - def all_reduce( - self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean" - ) -> Tensor: - if isinstance(tensor, Tensor): - return _sync_ddp_if_available(tensor, group, reduce_op=reduce_op) - return tensor - - @override - def barrier(self, *args: Any, **kwargs: Any) -> None: - if not _distributed_is_initialized(): - return - if torch.distributed.get_backend() == "nccl": - 
torch.distributed.barrier(device_ids=[self.root_device.index]) - else: - torch.distributed.barrier() - - @override - def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: - if not _distributed_is_initialized(): - return obj - - obj = [obj] - torch.distributed.broadcast_object_list(obj, src) - return obj[0] - - def _setup_distributed(self) -> None: - self._set_world_ranks() - self._process_group_backend = self._get_process_group_backend() - assert self.cluster_environment is not None - _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout) - - def _get_process_group_backend(self) -> str: - return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device) - - def _set_world_ranks(self) -> None: - if self.cluster_environment is not None: - self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank) - self.cluster_environment.set_world_size(self.num_nodes * self.num_processes) - # `LightningEnvironment.set_global_rank` will do this too, but we cannot rely on that implementation detail - # additionally, for some implementations, the setter is a no-op, so it's safer to access the getter - rank_zero_only.rank = utils_rank_zero_only.rank = self.global_rank - - -def _validate_executors(executors: Optional[Tuple[Union["Executor", str], ...]]) -> Optional[Tuple["Executor", ...]]: - """Converts string executors into it's respective ``Executor`` object.""" - if executors is None: - return None - from thunder import get_all_executors - - final = [] - issues = [] - all = get_all_executors() - for executor in executors: - if isinstance(executor, str): - for existing in all: - if executor == existing.name: - final.append(existing) - break - else: - issues.append(executor) - else: - final.append(executor) - if issues: - raise ValueError(f"Did not find the executors {issues} in {all}") - return tuple(final) - - -class 
_DDPBackwardSyncControl(_BackwardSyncControl): - def __init__(self): - self._enabled = False - - @override - def no_backward_sync(self, module: Module, enabled: bool) -> ContextManager: - if not getattr(module, "use_ddp", False): - raise TypeError( - "Blocking backward sync is only possible if the module passed to" - f" `{self.__class__.__name__}.no_backward_sync` is applied DDP." - f" Got: {module.__class__.__name__}." - ) - - # issue "Limitations of the current DDP no_sync implementation" has - # details on why we cannot just return `module.no_sync()` - from thunder.distributed import skip_data_parallel_grad_sync - - previous, self._enabled = self._enabled, enabled - if enabled: - return skip_data_parallel_grad_sync() - if not enabled and previous: - return _AllReduceGradsContextManager(module) - return nullcontext() - - -class _AllReduceGradsContextManager: - def __init__(self, module: Module) -> None: - self._module = module - - @override - def __enter__(self) -> None: - from thunder.distributed import _sync_grads - - _sync_grads(self._module) - - @override - def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: - pass diff --git a/examples/lit-gpt/_fsdp_thunder.py b/examples/lit-gpt/_fsdp_thunder.py deleted file mode 100644 index 133c40b1f2..0000000000 --- a/examples/lit-gpt/_fsdp_thunder.py +++ /dev/null @@ -1,420 +0,0 @@ -"""Fabric Strategy to support Thunder FSDP: To be upstreamed into Fabric eventually.""" - -import shutil -from contextlib import ExitStack, nullcontext -from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, List, Literal, Optional, Tuple, Union - -import torch -from lightning_utilities.core.imports import RequirementCache -from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only -from torch import Tensor -from torch.nn import Module -from torch.optim import Optimizer -from typing_extensions import override - -from 
lightning.fabric.accelerators.accelerator import Accelerator -from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment -from lightning.fabric.plugins.io.checkpoint_io import CheckpointIO -from lightning.fabric.plugins.precision import Precision -from lightning.fabric.strategies.launchers.subprocess_script import _SubprocessScriptLauncher -from lightning.fabric.strategies.parallel import ParallelStrategy -from lightning.fabric.strategies.strategy import TBroadcast, _apply_filter, _Sharded, _validate_keys_for_strict_loading -from lightning.fabric.utilities.distributed import ( - ReduceOp, - _distributed_is_initialized, - _get_default_process_group_backend_for_device, - _init_dist_connection, - _sync_ddp_if_available, -) -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2 -from lightning.fabric.utilities.load import _METADATA_FILENAME, _move_state_into -from lightning.fabric.utilities.rank_zero import rank_zero_only -from lightning.fabric.utilities.seed import reset_seed -from lightning.fabric.utilities.types import _PATH, _Stateful - -if TYPE_CHECKING: - from thunder import Executor - from thunder.distributed import FSDPBucketingStrategy, FSDPType - from thunder.distributed.checkpoint import StateDictOptions - - _FSDP_TYPE = Union[FSDPType, Literal["ZERO2", "ZERO3"]] - _BUCKETING_STRATEGY = Union[FSDPBucketingStrategy, Literal["NONE", "LAYER", "BLOCK"]] - - -_THUNDER_AVAILABLE = RequirementCache("lightning-thunder", "thunder") - - -class FSDPThunderStrategy(ParallelStrategy, _Sharded): - def __init__( - self, - accelerator: Optional[Accelerator] = None, - parallel_devices: Optional[List[torch.device]] = None, - cluster_environment: Optional[ClusterEnvironment] = None, - checkpoint_io: Optional[CheckpointIO] = None, - precision: Optional[Precision] = None, - sharding_strategy: "_FSDP_TYPE" = "ZERO3", - bucketing_strategy: "_BUCKETING_STRATEGY" = "NONE", - executors: Optional[Tuple[Union["Executor", str], ...]] = 
None, - state_dict_type: Literal["full", "sharded"] = "sharded", - **kwargs: Any, - ): - if not _TORCH_GREATER_EQUAL_2_2: - raise ImportError("Thunder's FSDP strategy requires PyTorch 2.2 or higher.") - if not _THUNDER_AVAILABLE: - raise ModuleNotFoundError(str(_THUNDER_AVAILABLE)) - super().__init__(accelerator=accelerator, checkpoint_io=checkpoint_io, precision=precision) - self.parallel_devices = parallel_devices - self.cluster_environment: Optional[ClusterEnvironment] = cluster_environment - from thunder.distributed import FSDPBucketingStrategy, FSDPType - - self.sharding_strategy = ( - FSDPType[sharding_strategy.upper()] if isinstance(sharding_strategy, str) else sharding_strategy - ) - self.bucketing_strategy = ( - FSDPBucketingStrategy[bucketing_strategy.upper()] - if isinstance(bucketing_strategy, str) - else bucketing_strategy - ) - self.executors = _validate_executors(executors) - self._state_dict_type = state_dict_type - self._fsdp_kwargs = kwargs - - @property - @override - def root_device(self) -> torch.device: - assert self.parallel_devices is not None - return self.parallel_devices[self.local_rank] - - @property - def num_nodes(self) -> int: - return 1 - - @property - def num_processes(self) -> int: - return len(self.parallel_devices) if self.parallel_devices is not None else 0 - - @property - @override - def distributed_sampler_kwargs(self) -> Dict[str, Any]: - return {"num_replicas": self.num_nodes * self.num_processes, "rank": self.global_rank} - - @override - def _configure_launcher(self) -> None: - assert self.cluster_environment is not None - if not self.cluster_environment.creates_processes_externally: - self._launcher = _SubprocessScriptLauncher(self.cluster_environment, self.num_processes, self.num_nodes) - - @override - def setup_environment(self) -> None: - super().setup_environment() - self._setup_distributed() - - @override - def setup_module(self, module: Module) -> Module: - import thunder - - module = thunder.distributed.fsdp( - 
module, - device=self.root_device, - sharding_strategy=self.sharding_strategy, - bucketing_strategy=self.bucketing_strategy, - **self._fsdp_kwargs, - ) - - # NOTE @IvanYaschuck says that `fsdp(jit(model))` could be supported in the future so that the user owns the `jit` call. - # we would still `jit(fsdp(undo_jit(jit(model))))` internally - return thunder.jit(module, executors=self.executors) - - @override - def module_to_device(self, module: Module) -> None: - pass - - @override - def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManager: - precision_init_ctx = self.precision.module_init_context() - module_sharded_ctx = self.module_sharded_context() - stack = ExitStack() - if empty_init: - # Materialization happens in `setup`. When modules get wrapped by FSDP - stack.enter_context(torch.device("meta")) - stack.enter_context(precision_init_ctx) - stack.enter_context(module_sharded_ctx) - return stack - - @override - def module_sharded_context(self) -> ContextManager: - return nullcontext() - - @override - def all_reduce( - self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean" - ) -> Tensor: - if isinstance(tensor, Tensor): - return _sync_ddp_if_available(tensor, group, reduce_op=reduce_op) - return tensor - - @override - def barrier(self, *args: Any, **kwargs: Any) -> None: - if not _distributed_is_initialized(): - return - if torch.distributed.get_backend() == "nccl": - torch.distributed.barrier(device_ids=[self.root_device.index]) - else: - torch.distributed.barrier() - - @override - def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: - if not _distributed_is_initialized(): - return obj - - obj = [obj] - torch.distributed.broadcast_object_list(obj, src) - return obj[0] - - @override - def clip_gradients_norm( - self, - module: Module, - optimizer: Optimizer, - max_norm: Union[float, int], - norm_type: Union[float, int] = 2.0, - error_if_nonfinite: bool = True, - ) -> Tensor: - 
raise NotImplementedError - - @override - def save_checkpoint( - self, - path: _PATH, - state: Dict[str, Union[Module, Optimizer, Any]], - storage_options: Optional[Any] = None, - filter: Optional[Dict[str, Callable[[str, Any], bool]]] = None, - ) -> None: - if storage_options is not None: - raise TypeError( - "`FSDPStrategy.save_checkpoint(..., storage_options=...)` is not supported because" - " `FSDPStrategy` does not use the `CheckpointIO`." - ) - if filter is not None: - raise NotImplementedError("Filtering checkpoint paths is not implemented") - - # broadcast the path from rank 0 to ensure all the states are saved in a common path - path = Path(self.broadcast(path)) - if path.is_dir() and self._state_dict_type == "full" and not _is_sharded_checkpoint(path): - raise IsADirectoryError(f"The checkpoint path exists and is a directory: {path}") - - from thunder.distributed.checkpoint import save, has_fsdp_modules, StateDictOptions - - modules = [module for module in state.values() if has_fsdp_modules(module)] - if len(modules) == 0: - raise ValueError( - "Could not find a FSDP model in the provided checkpoint state. Please provide the model as" - " part of the state like so: `save_checkpoint(..., state={'model': model, ...})`. Make sure" - " you set up the model (and optimizers if any) through the strategy before saving the checkpoint." - ) - if len(modules) > 1: - raise ValueError( - "Found multiple FSDP models in the given state. Saving checkpoints with FSDP is" - " currently limited to a single model per checkpoint. To save multiple models, call the" - " save method for each model separately with a different path." 
- ) - - if self._state_dict_type == "sharded": - if _is_full_checkpoint(path): - path.unlink() - path.mkdir(parents=True, exist_ok=True) - - options = StateDictOptions(full_state_dict=False, cpu_offload=True, rank0_only=False) - converted_state, metadata = _get_state_dict(state, filter, options, self.local_rank) - save(converted_state, path) - if self.global_rank == 0: - torch.save(metadata, path / _METADATA_FILENAME) - - elif self._state_dict_type == "full": - if _is_sharded_checkpoint(path): - shutil.rmtree(path) - - options = StateDictOptions(full_state_dict=True, cpu_offload=True, rank0_only=True) - converted_state, metadata = _get_state_dict(state, filter, options, self.local_rank) - converted_state.update(metadata) - if self.global_rank == 0: - torch.save(converted_state, path) - else: - raise ValueError(f"Unknown state_dict_type: {self._state_dict_type}") - - @override - def load_checkpoint( - self, - path: _PATH, - state: Optional[Union[Module, Optimizer, Dict[str, Union[Module, Optimizer, Any]]]] = None, - strict: bool = True, - ) -> Dict[str, Any]: - if not state: - raise ValueError( - f"Got `FSDPStrategy.load_checkpoint(..., state={state!r})` but a state with at least" - " a model instance to reload is required. Pass it in like so:" - " `FSDPStrategy.load_checkpoint(..., state={'model': model, ...})`" - ) - # broadcast the path from rank 0 to ensure all the states are loaded from a common path - path = Path(self.broadcast(path)) - - from thunder.distributed.checkpoint import has_fsdp_modules, StateDictOptions, load_model_state_dict, load - - if isinstance(state, Module): - if not _is_full_checkpoint(path): - raise ValueError( - "Failed to load checkpoint directly into the model. 
The given path must be a single file" - f" containing the full state dict: {path}" - ) - state_dict = torch.load(str(path), mmap=True, map_location="cpu") - options = StateDictOptions(full_state_dict=True, cpu_offload=True, strict=strict, rank0_only=False) - load_model_state_dict(state_dict, _unwrap_tom(state), options, self.local_rank) - return {} - - if isinstance(state, Optimizer): - raise NotImplementedError( - "Loading a single optimizer object from a checkpoint is not supported yet with the FSDP strategy." - ) - - modules = {key: module for key, module in state.items() if has_fsdp_modules(module)} - if len(modules) == 0: - raise ValueError( - "Could not find a FSDP model in the provided checkpoint state. Please provide the model as" - " part of the state like so: `load_checkpoint(..., state={'model': model, ...})`. Make sure" - " you set up the model (and optimizers if any) through the strategy before loading the checkpoint." - ) - if len(modules) > 1: - raise ValueError( - "Found multiple FSDP models in the given state. Loading checkpoints with FSDP is" - " currently limited to a single model per checkpoint. To load multiple models, call the" - " load method for each model separately with a different path." 
- ) - optimizers = {key: optim for key, optim in state.items() if isinstance(optim, Optimizer)} - module_key, module = list(modules.items())[0] - module = _unwrap_tom(module) - - if _is_sharded_checkpoint(path): - options = StateDictOptions(full_state_dict=False, cpu_offload=True, strict=strict, rank0_only=False) - # Load the DCP state dict, which requires a holder state dict - converted_state, _ = _get_state_dict(state, None, options, self.local_rank) - load(converted_state, path) - load_model_state_dict(converted_state[module_key], module, options, self.local_rank) - - # Load metadata (anything not a module or optimizer) - metadata = torch.load(path / _METADATA_FILENAME) - requested_metadata_keys = state.keys() - modules.keys() - optimizers.keys() - _validate_keys_for_strict_loading(requested_metadata_keys, metadata.keys(), strict=strict) - for key in requested_metadata_keys: - if key not in metadata: - continue - state[key] = metadata.pop(key) - # return the remaining metadata that wasn't requested as part of `state` - return metadata - - if _is_full_checkpoint(path): - options = StateDictOptions(full_state_dict=True, cpu_offload=True, strict=strict, rank0_only=False) - if not options.rank0_only or self.local_rank == 0: - map_location = "cpu" if options.cpu_offload else None - checkpoint = torch.load(str(path), mmap=True, map_location=map_location) - load_model_state_dict(checkpoint[module_key], module, options, self.local_rank) - else: - checkpoint = {} - - requested_metadata_keys = state.keys() - modules.keys() - optimizers.keys() - _validate_keys_for_strict_loading(requested_metadata_keys, checkpoint.keys(), strict=strict) - # Load metadata (anything not a module or optimizer) - _move_state_into(source=checkpoint, destination=state, keys=requested_metadata_keys) - # return the remaining metadata that wasn't requested as part of `state` - return checkpoint - - raise ValueError( - f"The path {str(path)!r} does not point to a valid checkpoint. 
Make sure the path points to either a" - " directory with FSDP checkpoint shards, or a single file with a full checkpoint." - ) - - def _setup_distributed(self) -> None: - reset_seed() - self._set_world_ranks() - process_group_backend = _get_default_process_group_backend_for_device(self.root_device) - assert self.cluster_environment is not None - _init_dist_connection(self.cluster_environment, process_group_backend) - - def _set_world_ranks(self) -> None: - if self.cluster_environment is not None: - self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank) - self.cluster_environment.set_world_size(self.num_nodes * self.num_processes) - # `LightningEnvironment.set_global_rank` will do this too, but we cannot rely on that implementation detail - # additionally, for some implementations, the setter is a no-op, so it's safer to access the getter - rank_zero_only.rank = utils_rank_zero_only.rank = self.global_rank - - -def _is_sharded_checkpoint(path: Path) -> bool: - """A heuristic check to determine whether the path points to a directory with checkpoint shards.""" - return path.is_dir() and (path / _METADATA_FILENAME).is_file() - - -def _is_full_checkpoint(path: Path) -> bool: - return path.is_file() - - -def _validate_executors(executors: Optional[Tuple[Union["Executor", str], ...]]) -> Optional[Tuple["Executor", ...]]: - """Converts string executors into it's respective ``Executor`` object.""" - if executors is None: - return None - from thunder import get_all_executors - - final = [] - issues = [] - all = get_all_executors() - for executor in executors: - if isinstance(executor, str): - for existing in all: - if executor == existing.name: - final.append(existing) - break - else: - issues.append(executor) - else: - final.append(executor) - if issues: - raise ValueError(f"Did not find the executors {issues} in {all}") - return tuple(final) - - -def _get_state_dict( - state: Dict[str, Any], - filter: Optional[Dict[str, 
Callable[[str, Any], bool]]], - options: "StateDictOptions", - rank: int, -) -> Tuple[Dict[str, Any], Dict[str, Any]]: - from thunder.distributed.checkpoint import get_model_state_dict - - # replace the modules and optimizer objects in the state with their local state dict - # and separate the user's metadata - converted_state: Dict[str, Any] = {} - metadata: Dict[str, Any] = {} - for key, obj in state.items(): - converted: Any - if isinstance(obj, Module): - converted = get_model_state_dict(_unwrap_tom(obj), options, rank) - target_dict = converted_state - elif isinstance(obj, Optimizer): - # TODO: optimizer support - converted = obj.state_dict() - target_dict = converted_state - else: # everything not a module or optimizer is considered metadata - converted = obj.state_dict() if isinstance(obj, _Stateful) else obj - target_dict = metadata - _apply_filter(key, filter or {}, converted, target_dict) - - return converted_state, metadata - - -def _unwrap_tom(obj: object) -> object: - # TODO: this unwrap won't be required when Fabric's `_unwrap_objects` supports Thunder - from thunder import ThunderModule - - if isinstance(obj, ThunderModule): - return obj._model - return obj diff --git a/examples/lit-gpt/test_ddp_thunder.py b/examples/lit-gpt/test_ddp_thunder.py deleted file mode 100644 index 860d14a8d4..0000000000 --- a/examples/lit-gpt/test_ddp_thunder.py +++ /dev/null @@ -1,39 +0,0 @@ -import pytest -import torch -from _ddp_thunder import DDPThunderStrategy - -from lightning import Fabric -# from tests.tests_fabric.helpers.runif import RunIf - - -# @RunIf(min_cuda_gpus=2, thunder=True, standalone=True) -@pytest.mark.parametrize("strategy", ["ddp", DDPThunderStrategy()]) -def test_no_backward_sync(strategy): - fabric = Fabric(devices=2, accelerator="cuda", strategy=strategy) - fabric.launch() - - model = torch.nn.Linear(1, 1, bias=False, device=fabric.device) - x = torch.randn(1, 1, device=fabric.device) - model = fabric.setup(model) - - # 6 iters, 3 grad 
accumulation iters - for i, enabled in enumerate((True, True, False, True, True, False), 1): - x = torch.tensor([i * (fabric.local_rank + 1)], device=fabric.device, dtype=torch.float32) - - with fabric.no_backward_sync(model, enabled): - y = model(x) - y.backward() - if not enabled: - # Math for the first 3 iters - # - # DistributedDataParallel - # (1*1+2*1+3*1 + 1*2+2*2+3*2) / 2 = 9 - # ^^^^^^^^^^^ ^^^^^^^^^^^ ^^^ - # rank0 rank1 allreduce - # - # thunder.distributed.ddp - # ((1*1+2*1) + (1*2+2*2)) / 2 + (3*1 + 3*2) / 2 = 9 - # ^^^^^^^ ^^^^^^^ ^^^ ^^^ ^^^ ^^^ - # rank0 rank1 allreduce1 rank0 rank1 allreduce2 - assert model.weight.grad.item() == (9.0 if i == 3 else 22.5) - model.weight.grad = None diff --git a/examples/lit-gpt/test_fsdp_thunder.py b/examples/lit-gpt/test_fsdp_thunder.py deleted file mode 100644 index a49b9131f7..0000000000 --- a/examples/lit-gpt/test_fsdp_thunder.py +++ /dev/null @@ -1,294 +0,0 @@ -from _fsdp_thunder import FSDPThunderStrategy, _validate_executors -from lightning.fabric import Fabric -import torch -import pytest -import re -import os -from typing import Optional, Tuple, Union -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3 - - -def test_thunder_strategy_input_parsing(): - from thunder.distributed import FSDPBucketingStrategy, FSDPType - from thunder import pythonex - - strategy = FSDPThunderStrategy(bucketing_strategy="BlOcK", executors_list=("python",), sharding_strategy="zero3") - assert strategy.bucketing_strategy is FSDPBucketingStrategy.BLOCK - assert strategy.executors_list == (pythonex,) - assert strategy.sharding_strategy is FSDPType.ZERO3 - - -def test_validate_executors(): - from thunder import pythonex, pytorch_executor - - assert _validate_executors(None) is None - assert _validate_executors((pythonex, pytorch_executor)) == (pythonex, pytorch_executor) - assert _validate_executors(("python", "torch")) == (pythonex, pytorch_executor) - assert _validate_executors(("python", pytorch_executor)) == 
(pythonex, pytorch_executor) - with pytest.raises(ValueError, match=re.escape("not find the executors ['foo', 'bar'] in")): - assert _validate_executors(("python", "foo", pytorch_executor, "bar")) - - -def test_save_checkpoint_invalid_settings_raise(tmp_path): - strategy = FSDPThunderStrategy(state_dict_type="full") - with pytest.raises(TypeError, match="not supported"): - strategy.save_checkpoint(tmp_path, {}, storage_options=object()) - - with pytest.raises(IsADirectoryError, match="path exists"): - strategy.save_checkpoint(tmp_path, {}) - - model = torch.nn.Linear(1, 1) - with pytest.raises(ValueError, match="Could not find"): - strategy.save_checkpoint(tmp_path / "foo", {}) - - model.use_fsdp = True - with pytest.raises(ValueError, match="Found multiple"): - strategy.save_checkpoint(tmp_path / "foo", {"model1": model, "model2": model}) - - with pytest.raises(ValueError, match="at least a model"): - strategy.load_checkpoint(tmp_path / "foo", {}) - - with pytest.raises(ValueError, match="must be a single file"): - strategy.load_checkpoint(tmp_path, model) - - optimizer = torch.optim.Adam(model.parameters()) - with pytest.raises(NotImplementedError, match="not supported"): - strategy.load_checkpoint(tmp_path, optimizer) - - with pytest.raises(ValueError, match="Found multiple"): - strategy.load_checkpoint(tmp_path / "foo", {"model1": model, "model2": model}) - - with pytest.raises(ValueError, match="Could not find"): - strategy.load_checkpoint(tmp_path / "foo", {"foo": 1}) - - -class Submodule(torch.nn.Module): - def __init__(self, h: int): - super().__init__() - self.l = torch.nn.Linear(4, h * 2, bias=False) - - def forward(self, x): - # defined just because preprocessing fails otherwise - ... 
- - -class MyModel(torch.nn.Module): - def __init__(self, h: int): - super().__init__() - self.register_buffer("buf", torch.tensor(0)) - self.l = torch.nn.Linear(2, h) - self.inner = Submodule(h) - - def forward(self): - # defined just because preprocessing fails otherwise - ... - - def reset_parameters(self): - self.buf = torch.empty_like(self.buf) - - -def test_materialize_meta_tensors(): - strategy = FSDPThunderStrategy() - fabric = Fabric(accelerator="cuda", devices=2, strategy=strategy) - fabric.launch() - - with fabric.init_module(empty_init=True): - model = MyModel(2) - - model = fabric.setup(model) - # all parameters were moved - assert len(list(model.parameters())) == 3 - assert all(p.device.type == "cuda" for p in model.parameters()) - # buffers were moved too - assert model.buf.device.type == "cuda" - - -class StatefulThing: - def state_dict(self): - return {"thing": 1} - - def load_state_dict(self, state_dict): - assert state_dict == self.state_dict() - - -class TensorLike: - def __init__(self, device: Optional[Union[str, torch.device]] = None, shape: Optional[Tuple[int, ...]] = None): - self.device = torch.device(device) if device is not None else None - self.shape = torch.Size(shape) if shape is not None else None - - def __eq__(self, other): - return ( - isinstance(other, torch.Tensor) - and (self.device is None or other.device == self.device) - and (self.shape is None or other.shape == self.shape) - ) - - -def test_save_load_full_checkpoint(tmp_path): - strategy = FSDPThunderStrategy(state_dict_type="full", broadcast_from=0) - fabric = Fabric(accelerator="cuda", devices=2, strategy=strategy) - fabric.launch() - - model = MyModel(4) - expected = model.state_dict() - - # save a sharded model - model = fabric.setup(model) - state = {"model": model, "stateful": StatefulThing(), "primitive": 123} - checkpoint_path = tmp_path / "foo" - fabric.save(checkpoint_path, state) - - # assert the file contents - if fabric.global_rank == 0: - checkpoint = 
torch.load(checkpoint_path) - # cpu_offload is enabled by default - assert checkpoint == { - "model": { - "buf": TensorLike("cpu", tuple()), - "inner.l.weight": TensorLike("cpu", (8, 4)), - "l.bias": TensorLike("cpu", (4,)), - "l.weight": TensorLike("cpu", (4, 2)), - }, - "stateful": {"thing": 1}, - "primitive": 123, - } - torch.testing.assert_close(checkpoint["model"], expected) - - # load its weights into a different sharded model - model = MyModel(4) - model = fabric.setup(model) - state = {"model": model, "stateful": StatefulThing(), "primitive": 321} - fabric.load(checkpoint_path, state) - - from thunder.distributed import _unshard_params - - # unshard this model's parameters to compare with the original state dict before sharding - _unshard_params(model, model.process_group_for_ddp, True) - # we loaded rank 0's weights, so this would fail in the other ranks - if fabric.global_rank == 0: - actual = model.state_dict() - # `_unshard_params` doesnt offload buffers at the moment - assert actual["buf"].device.type == "cuda" - actual["buf"] = actual["buf"].to(device="cpu") - torch.testing.assert_close(actual, expected) - assert state["primitive"] == 123 - - -def test_load_full_checkpoint_only_model(tmp_path): - strategy = FSDPThunderStrategy() - fabric = Fabric(accelerator="cuda", devices=2, strategy=strategy) - fabric.launch() - - checkpoint_path = tmp_path / "foo" - checkpoint_path = fabric.broadcast(checkpoint_path) - if fabric.global_rank == 0: - model = MyModel(4) - expected = model.state_dict() - torch.save(expected, checkpoint_path) - fabric.barrier() - expected = torch.load(checkpoint_path) - - # before sharding - model = MyModel(4) - fabric.load_raw(checkpoint_path, model) - torch.testing.assert_close(model.state_dict(), expected) - - # after sharding - model = MyModel(4) - model = fabric.setup(model) - fabric.load_raw(checkpoint_path, model) - from thunder.distributed import _unshard_params - - # unshard this model's parameters to compare with the original 
state dict before sharding - _unshard_params(model, model.process_group_for_ddp, True) - actual = model.state_dict() - # `_unshard_params` doesnt offload buffers at the moment - assert actual["buf"].device.type == "cuda" - actual["buf"] = actual["buf"].to(device="cpu") - torch.testing.assert_close(actual, expected) - - -def distributed_ckpt_to_regular(path): - """From ``torch.distributed.checkpoint.format_utils.dcp_to_torch_save``.""" - from torch.distributed.checkpoint.state_dict_loader import _load_state_dict - from torch.distributed.checkpoint import FileSystemReader - - if _TORCH_GREATER_EQUAL_2_3: - from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner - else: - from torch.distributed.checkpoint._traverse import set_element - from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner - from torch.distributed.checkpoint.metadata import TensorStorageMetadata - - class _EmptyStateDictLoadPlanner(DefaultLoadPlanner): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def set_up_planner(self, state_dict, metadata, is_coordinator): - assert not state_dict - # rebuild the state dict from the metadata - for k, v in metadata.state_dict_metadata.items(): - if isinstance(v, TensorStorageMetadata): - v = torch.empty(v.size, dtype=v.properties.dtype) - if k in metadata.planner_data: - set_element(state_dict, metadata.planner_data[k], v) - else: - state_dict[k] = v - super().set_up_planner(state_dict, metadata, is_coordinator) - - state_dict = {} - storage_reader = FileSystemReader(path) - _load_state_dict(state_dict, storage_reader=storage_reader, planner=_EmptyStateDictLoadPlanner(), no_dist=True) - return state_dict - - -def test_save_load_sharded_checkpoint(tmp_path): - strategy = FSDPThunderStrategy(state_dict_type="sharded", broadcast_from=0) - fabric = Fabric(accelerator="cuda", devices=2, strategy=strategy) - fabric.launch() - - model = MyModel(4) - expected = model.state_dict() - - # save a 
sharded model - model = fabric.setup(model) - state = {"model": model, "stateful": StatefulThing(), "primitive": 123} - fabric.save(tmp_path, state) - - # assert the file contents - if fabric.global_rank == 0: - assert set(os.listdir(tmp_path)) == {"meta.pt", "__1_0.distcp", "__0_0.distcp", ".metadata"} - - metadata = torch.load(tmp_path / "meta.pt") - assert metadata == {"stateful": {"thing": 1}, "primitive": 123} - - checkpoint = distributed_ckpt_to_regular(tmp_path) - # cpu_offload is enabled by default - assert checkpoint == { - "model": { - "buf": TensorLike("cpu", tuple()), - "inner.l.weight": TensorLike("cpu", (8, 4)), - "l.bias": TensorLike("cpu", (4,)), - "l.weight": TensorLike("cpu", (4, 2)), - } - } - torch.testing.assert_close(checkpoint["model"], expected) - - # load its weights into a different sharded model - model = MyModel(4) - model = fabric.setup(model) - state = {"model": model, "stateful": StatefulThing(), "primitive": 321} - fabric.load(tmp_path, state) - - from thunder.distributed import _unshard_params - - # unshard this model's parameters to compare with the original state dict before sharding - _unshard_params(model, model.process_group_for_ddp, True) - # we loaded rank 0's weights, so this would fail in the other ranks - if fabric.global_rank == 0: - actual = model.state_dict() - # `_unshard_params` doesnt offload buffers at the moment - assert actual["buf"].device.type == "cuda" - actual["buf"] = actual["buf"].to(device="cpu") - torch.testing.assert_close(actual, expected) - assert state["primitive"] == 123 diff --git a/examples/lit-gpt/train.py b/examples/lit-gpt/train.py deleted file mode 100644 index 412711ce5a..0000000000 --- a/examples/lit-gpt/train.py +++ /dev/null @@ -1,111 +0,0 @@ -import time - -import lightning as L -import torch -from torch.utils.data import DataLoader, IterableDataset - -from thunder.tests.lit_gpt_model import GPT, Config - -model_name = "open_llama_3b" -learning_rate = 6e-4 -micro_batch_size = 2 -max_iters = 
50 - - -def main(compile: str = "eager", dynamic: bool = False) -> None: - fabric = L.Fabric(devices=1, precision="bf16-true") - - fabric.seed_everything(42, workers=True) # same seed for every process to init model (FSDP) - - config = Config.from_name(model_name) - print(f"Loading model with {config.__dict__}") - t0 = time.perf_counter() - with fabric.init_module(): - og_model = model = GPT(config) - print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.") - - if compile == "inductor": - model = torch.compile(model, fullgraph=True, mode="reduce-overhead", dynamic=dynamic) - elif compile == "thunder": - import thunder - from thunder.executors.sdpaex import sdpa_ex - from thunder.executors.torch_compile import torch_compile_executor - - model = thunder.jit( - model, - executors=[sdpa_ex, torch_compile_executor, thunder.nvfuser_executor, thunder.pytorch_executor], - # TODO: we'd want to enable CUDAGraphs for parity with `torch.compile` but it goes OOM - ) - model.max_seq_length = og_model.max_seq_length - elif compile != "eager": - raise ValueError(compile) - - model = fabric.setup(model) - optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay=1e-1, foreach=False) - optimizer = fabric.setup_optimizers(optimizer) - - train_data = DummyDataset(model.max_seq_length, dynamic) - train_dataloader = DataLoader(train_data, batch_size=micro_batch_size, num_workers=2, collate_fn=pad_collate) - train_dataloader = fabric.setup_dataloaders(train_dataloader) - - train(fabric, model, optimizer, train_dataloader) - print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB") - - -def train( - fabric: L.Fabric, model: torch.nn.Module, optimizer: torch.optim.Optimizer, train_dataloader: DataLoader -) -> None: - train_iter = iter(train_dataloader) - t0 = None - assert max_iters > 5 - for i in range(max_iters): - iter_t0 = time.perf_counter() - if i == 5: # warmup - t0 = iter_t0 - input_ids, targets = next(train_iter) - - logits 
= model(input_ids) - logits = logits.reshape(-1, logits.size(-1)) - targets = targets.reshape(-1) - loss = torch.nn.functional.cross_entropy(logits, targets, ignore_index=-1) - fabric.backward(loss) - optimizer.step() - optimizer.zero_grad() - - loss_item = loss.item() # synchronization - t1 = time.perf_counter() - print(f"iter {i}: loss {loss_item:.4f}, iter time: {(t1 - iter_t0) * 1000:.2f}ms, t: {input_ids.size(1)}") - print(f"Total time: {(t1 - t0):.2f}s") - - -class DummyDataset(IterableDataset): - def __init__(self, max_seq_length: int, dynamic: bool): - super().__init__() - self.max_seq_length = max_seq_length - self.dynamic = dynamic - - def __iter__(self): - while True: - if self.dynamic: - t = torch.randint(10, self.max_seq_length + 1, (1,)) - else: - t = self.max_seq_length - data = torch.randint(0, 100, (t + 1,), dtype=torch.int64) - x = data[:t] - y = data[1 : t + 1] - yield x, y - - -def pad_collate(batch): - x, y = zip(*batch) - x_padded = torch.nn.utils.rnn.pad_sequence(x, batch_first=True, padding_value=0) - y_padded = torch.nn.utils.rnn.pad_sequence(y, batch_first=True, padding_value=-1) - return x_padded, y_padded - - -if __name__ == "__main__": - torch.set_float32_matmul_precision("high") - - from jsonargparse import CLI - - CLI(main) diff --git a/examples/lit-gpt/train_fsdp.py b/examples/lit-gpt/train_fsdp.py deleted file mode 100644 index e896d52ef3..0000000000 --- a/examples/lit-gpt/train_fsdp.py +++ /dev/null @@ -1,122 +0,0 @@ -import logging -import re -import time -from typing import Literal - -import lightning as L -import torch -from lightning.fabric.strategies import FSDPStrategy -from torch.distributed.fsdp.wrap import always_wrap_policy -from torch.utils.data import DataLoader, IterableDataset - -from _fsdp_thunder import FSDPThunderStrategy -from thunder.tests.lit_gpt_model import GPT, Block, Config - - -model_name = "open_llama_3b" -learning_rate = 6e-4 -micro_batch_size = 2 -max_iters = 50 - - -def main( - compile: str = "eager", 
devices: int = 2, stage: str = "2", bucketing_strategy: Literal["NONE", "BLOCK"] = "NONE" -) -> None: - fsdp_type = {"2": "ZERO2", "3": "ZERO3"}[stage] - sharding_strategy = {"2": "SHARD_GRAD_OP", "3": "FULL_SHARD"}[stage] - auto_wrap_policy = always_wrap_policy if bucketing_strategy.lower() == "none" else {Block} - strategy = ( - FSDPThunderStrategy( - sharding_strategy=fsdp_type, - bucketing_strategy=bucketing_strategy, - executors=("sdpa", "torchcompile", "nvfuser", "torch"), - ) - if compile == "thunder" - else FSDPStrategy(auto_wrap_policy=auto_wrap_policy, sharding_strategy=sharding_strategy) - ) - - fabric = L.Fabric(devices=devices, strategy=strategy, precision="bf16-true") - fabric.launch() - - fabric.seed_everything(42, workers=True) # same seed for every process to init model (FSDP) - - config = Config.from_name(model_name) - fabric.print(f"Loading model with {config.__dict__}") - t0 = time.perf_counter() - with fabric.init_module(empty_init=True): - og_model = model = GPT(config) - fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.") - - if compile == "inductor": - # Suppress excessive warnings, see https://github.com/pytorch/pytorch/issues/111632 - pattern = re.compile(".*Profiler function .* will be ignored") - logging.getLogger("torch._dynamo.variables.torch").addFilter( - lambda record: not pattern.search(record.getMessage()) - ) - - model = torch.compile(model) - elif compile == "thunder": - pass # fabric.setup does this - elif compile != "eager": - raise ValueError(compile) - - model = fabric.setup(model) - if compile == "thunder": - model.max_seq_length = og_model.max_seq_length - optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay=1e-1, foreach=False) - optimizer = fabric.setup_optimizers(optimizer) - - train_data = DummyDataset(model.max_seq_length) - train_dataloader = DataLoader(train_data, batch_size=micro_batch_size, num_workers=2) - train_dataloader = 
fabric.setup_dataloaders(train_dataloader) - - train(fabric, model, optimizer, train_dataloader) - fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB") - - -def train( - fabric: L.Fabric, model: torch.nn.Module, optimizer: torch.optim.Optimizer, train_dataloader: DataLoader -) -> None: - train_iter = iter(train_dataloader) - t0 = None - assert max_iters > 5 - for i in range(max_iters): - iter_t0 = time.perf_counter() - if i == 5: # warmup - t0 = iter_t0 - input_ids, targets = next(train_iter) - - logits = model(input_ids) - logits = logits.reshape(-1, logits.size(-1)) - targets = targets.reshape(-1) - loss = torch.nn.functional.cross_entropy(logits, targets, ignore_index=-1) - fabric.backward(loss) - optimizer.step() - optimizer.zero_grad() - - loss_item = loss.item() # synchronization - t1 = time.perf_counter() - fabric.print(f"iter {i}: loss {loss_item :.4f}, iter time: {(t1 - iter_t0) * 1000:.2f}ms") - fabric.print(f"Total time: {(t1 - t0):.2f}s") - - -class DummyDataset(IterableDataset): - def __init__(self, max_seq_length: int): - super().__init__() - self.max_seq_length = max_seq_length - - def __iter__(self): - t = self.max_seq_length - while True: - data = torch.randint(0, 100, (t + 1,), dtype=torch.int64) - x = data[:t] - y = data[1 : t + 1] - yield x, y - - -if __name__ == "__main__": - torch.set_float32_matmul_precision("high") - - from jsonargparse import CLI - - CLI(main) diff --git a/notebooks/dev_tutorials/fsdp_tutorial.ipynb b/notebooks/dev_tutorials/fsdp_tutorial.ipynb index a4f61b47c3..b8cef2a2ff 100644 --- a/notebooks/dev_tutorials/fsdp_tutorial.ipynb +++ b/notebooks/dev_tutorials/fsdp_tutorial.ipynb @@ -1764,7 +1764,7 @@ "%%writefile thunder_fsdp_simple_example.py\n", "\n", "# imports\n", - "from thunder.tests.lit_gpt_model import GPT, Config\n", + "from thunder.tests.litgpt_model import GPT, Config\n", "import torch\n", "import torch.distributed\n", "import thunder\n", diff --git a/notebooks/zero_to_thunder.ipynb 
b/notebooks/zero_to_thunder.ipynb index a1a888cc72..ffd6b5fe1c 100644 --- a/notebooks/zero_to_thunder.ipynb +++ b/notebooks/zero_to_thunder.ipynb @@ -312,8 +312,8 @@ } ], "source": [ - "from lit_gpt import GPT\n", - "from thunder.tests.lit_gpt_model import Config\n", + "from litgpt import GPT\n", + "from thunder.tests.litgpt_model import Config\n", "cfg = Config.from_name('Llama-2-7b-hf')\n", "cfg.n_layer = 16 # fewer layers\n", "torch.set_default_dtype(torch.bfloat16)\n", @@ -3326,7 +3326,7 @@ ], "source": [ "%%writefile zero_to_thunder_fsdp_simple_example.py\n", - "from thunder.tests.lit_gpt_model import GPT, Config\n", + "from thunder.tests.litgpt_model import GPT, Config\n", "import os\n", "import torch, torch.distributed\n", "import thunder, thunder.distributed\n", @@ -3470,7 +3470,7 @@ }, "outputs": [], "source": [ - "import lit_gpt\n", + "import litgpt\n", "def apply_rope_copy(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:\n", " head_size = x.size(-1)\n", " x1 = x[..., : head_size // 2] # (B, nh, T, hs/2)\n", @@ -3493,7 +3493,7 @@ "\n", "Say we have a function `apply_rope` applying the RoPE transformation in PyTorch.\n", "\n", - "In thunder, we define a *meta* function that only defines the metadata (like shapes) of outputs and the actual implementation for each operator and then register the pair with our executor using the `register_operator` function and tell it to use the new symbol instead of the original function `lit_gpt.model.apply_rope`.\n" + "In thunder, we define a *meta* function that only defines the metadata (like shapes) of outputs and the actual implementation for each operator and then register the pair with our executor using the `register_operator` function and tell it to use the new symbol instead of the original function `litgpt.model.apply_rope`.\n" ] }, { @@ -3504,17 +3504,17 @@ "outputs": [], "source": [ "import torch, thunder\n", - "from thunder.tests.lit_gpt_model import GPT\n", + "from 
thunder.tests.litgpt_model import GPT\n", "from thunder import TensorProxy\n", "\n", "def apply_rope_impl(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:\n", - " return lit_gpt.model.apply_rope(x, cos, sin)\n", + " return litgpt.model.apply_rope(x, cos, sin)\n", "\n", "def apply_rope_meta(x: TensorProxy, cos: TensorProxy, sin: TensorProxy) -> TensorProxy:\n", " return TensorProxy(like=x)\n", "\n", "apply_rope = my_ex.register_operator('apply_rope', like=apply_rope_meta, fn=apply_rope_impl,\n", - " replaces=lit_gpt.model.apply_rope)" + " replaces=litgpt.model.apply_rope)" ] }, { @@ -3569,7 +3569,7 @@ "with torch.device('cuda'): m = GPT.from_name('llama2-like'); Q = torch.randn(2, 128, 4096, 16)\n", "\n", "def test_apply_rope(x, m):\n", - " return lit_gpt.model.apply_rope(x, m.cos, m.sin)\n", + " return litgpt.model.apply_rope(x, m.cos, m.sin)\n", "\n", "thunder_apply_rope = thunder.jit(test_apply_rope, executors=(my_ex,) + thunder.get_default_executors()) \n", "\n", diff --git a/requirements/notebooks.txt b/requirements/notebooks.txt index 47a14902ca..7c37419cd6 100644 --- a/requirements/notebooks.txt +++ b/requirements/notebooks.txt @@ -1 +1,3 @@ ipython[all] ==8.22.2 + +litgpt @ git+https://github.com/Lightning-AI/lit-gpt@24d5eba1724c953b7506edc041a7da1ce226c129 diff --git a/requirements/test.txt b/requirements/test.txt index a1402bd69b..621ebaa7c5 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -11,11 +11,11 @@ expecttest ==0.2.1 # for test_ddp.py hypothesis ==6.99.10 # for test_ddp.py numpy # for test_ops.py einops # for test_einops.py -lit_gpt @ git+https://github.com/Lightning-AI/lit-gpt@f241d94df59d82b2017bfdcd3800ac8779eb45f5 -absl-py # for test_parametrized.py in examples/lit-gpt -pandas # for test_parametrized.py in examples/lit-gpt -xlsxwriter # for test_parametrized.py in examples/lit-gpt -jsonargparse # for benchmarking_litgpt.py in thunder/benchmarks +litgpt @ 
git+https://github.com/Lightning-AI/lit-gpt@24d5eba1724c953b7506edc041a7da1ce226c129 +absl-py # thunder/benchmarks/test_benchmark_litgpt.py +pandas # thunder/benchmarks/test_benchmark_litgpt.py +xlsxwriter # thunder/benchmarks/test_benchmark_litgpt.py +jsonargparse # thunder/benchmarks/benchmark_litgpt.py # Installs JAX on Linux and MacOS jaxlib; sys_platform == 'linux' or sys_platform == 'darwin' # required for jax, see https://github.com/google/jax#installation diff --git a/thunder/benchmarks/__init__.py b/thunder/benchmarks/__init__.py index e9ba261c45..f4b0f2ac03 100644 --- a/thunder/benchmarks/__init__.py +++ b/thunder/benchmarks/__init__.py @@ -24,9 +24,9 @@ import thunder.core.devices as Devices from thunder.core.transforms import grad, clear_grads, populate_grads import thunder.executors as executors -from thunder.tests import nanogpt_model, hf_bart_self_attn, lit_gpt_model +from thunder.tests import nanogpt_model, hf_bart_self_attn, litgpt_model from thunder.tests.make_tensor import make_tensor, make_tensor_like -from thunder.tests.lit_gpt_model import Config as LitGPTConfig +from thunder.tests.litgpt_model import Config as LitGPTConfig # List of all benchmarks benchmarks: list = [] @@ -1875,7 +1875,7 @@ class LlamaMLPBenchmark(Benchmark, metaclass=UserFacingBenchmarkMeta): _args = ( BenchmarkArg( name="config", - description="The Lit-GPT config to use. Default is 'Llama-2-7b-hf'. See the lit_gpt_model.py for details.", + description="The Lit-GPT config to use. Default is 'Llama-2-7b-hf'. 
See the litgpt_model.py for details.", ), BenchmarkArg( name="batchdims", @@ -1935,7 +1935,7 @@ def make_batch(self) -> tuple[list, dict]: def fn(self) -> Callable: module = ( - lit_gpt_model.LLaMAMLP(self.config) + litgpt_model.LLaMAMLP(self.config) .to(device=self.device, dtype=self.tdtype) .requires_grad_(self.requires_grad) ) @@ -1946,7 +1946,7 @@ class LitGPTCausalSelfAttentionBenchmark(Benchmark, metaclass=UserFacingBenchmar _args = ( BenchmarkArg( name="config", - description="The Lit-GPT config to use. Default is 'Llama-2-7b-hf'. See the lit_gpt_model.py for details.", + description="The Lit-GPT config to use. Default is 'Llama-2-7b-hf'. See the litgpt_model.py for details.", ), BenchmarkArg( name="batchdims", @@ -2005,7 +2005,7 @@ def make_batch(self) -> tuple[list, dict]: def fn(self) -> Callable: module = ( - lit_gpt_model.CausalSelfAttention(self.config) + litgpt_model.CausalSelfAttention(self.config) .to(device=self.device, dtype=self.tdtype) .requires_grad_(self.requires_grad) ) @@ -2086,7 +2086,7 @@ def make_batch(self) -> tuple[list, dict]: def fn(self) -> Callable: module = ( - lit_gpt_model.RMSNorm(self.size, self.dim, self.eps) + litgpt_model.RMSNorm(self.size, self.dim, self.eps) .to(device=self.device, dtype=self.tdtype) .requires_grad_(self.requires_grad) ) @@ -2168,7 +2168,7 @@ def make_batch(self) -> tuple[list, dict]: def fn(self) -> Callable: gpt = ( - lit_gpt_model.GPT(self.config) + litgpt_model.GPT(self.config) .to(device=self.device, dtype=self.model_tdtype) .requires_grad_(self.requires_grad) ) @@ -2199,7 +2199,7 @@ def __init__(self, config, use_apex) -> None: super().__init__() self.config = config - self.apply_rope = lit_gpt_model.apply_rope + self.apply_rope = litgpt_model.apply_rope self.use_apex = use_apex def forward( @@ -2254,7 +2254,7 @@ class LlamaQKVSplitRopeBenchmark(Benchmark, metaclass=UserFacingBenchmarkMeta): _args = ( BenchmarkArg( name="config", - description="The Lit-GPT config to use. Default is 'Llama-2-7b-hf'. 
See the lit_gpt_model.py for details.", + description="The Lit-GPT config to use. Default is 'Llama-2-7b-hf'. See the litgpt_model.py for details.", ), BenchmarkArg( name="batchdims", @@ -2610,7 +2610,7 @@ def __init__( dtype: dtypes.dtype = thunder.bfloat16, requires_grad: bool = True, ) -> None: - from thunder.tests.lit_gpt_model import Config + from thunder.tests.litgpt_model import Config litgptconfig = Config.from_name(config) if not isinstance(config, Config) else config nanogptconfig = NanoGPTConfig( @@ -2793,7 +2793,7 @@ def __init__( # Sets required benchmark parameters self.devices: list[str] = [device] - self.cos, self.sin = lit_gpt_model.build_rope_cache( + self.cos, self.sin = litgpt_model.build_rope_cache( seq_len=seq_length, n_elem=self.config.rope_n_elem, device=self.device ) @@ -2806,9 +2806,7 @@ def make_batch(self) -> tuple[list, dict]: def fn(self) -> Callable: model = ( - lit_gpt_model.Block(self.config) - .to(device=self.device, dtype=self.tdtype) - .requires_grad_(self.requires_grad) + litgpt_model.Block(self.config).to(device=self.device, dtype=self.tdtype).requires_grad_(self.requires_grad) ) return model diff --git a/thunder/benchmarks/benchmark_litgpt.py b/thunder/benchmarks/benchmark_litgpt.py index 9120584989..877f99f38e 100644 --- a/thunder/benchmarks/benchmark_litgpt.py +++ b/thunder/benchmarks/benchmark_litgpt.py @@ -7,7 +7,7 @@ import torch.distributed as torch_dist import thunder -from thunder.tests.lit_gpt_model import Config, GPT, Block +from thunder.tests.litgpt_model import Config, GPT, Block from lightning.fabric.utilities.throughput import measure_flops from lightning.fabric.utilities import Throughput diff --git a/thunder/benchmarks/distributed.py b/thunder/benchmarks/distributed.py index c46a699ca5..3437c47e30 100644 --- a/thunder/benchmarks/distributed.py +++ b/thunder/benchmarks/distributed.py @@ -14,7 +14,7 @@ LitGPTBenchmark, LitGPTConfig, ) -from thunder.tests.lit_gpt_model import name_to_config +from 
thunder.tests.litgpt_model import name_to_config from thunder.distributed import FSDPBucketingStrategy from thunder.distributed import FSDPType @@ -299,7 +299,7 @@ def parse_args() -> argparse.Namespace: from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy from thunder.benchmarks import get_default_torch_fsdp_executor from thunder.tests.nanogpt_model import Block as NanoGPTBlock - from thunder.tests.lit_gpt_model import Block as GPTBlock + from thunder.tests.litgpt_model import Block as GPTBlock sharding_strategy = ShardingStrategy.SHARD_GRAD_OP auto_wrap_policies = ( diff --git a/thunder/benchmarks/targets.py b/thunder/benchmarks/targets.py index 9f02b23d35..9b38e26c3d 100644 --- a/thunder/benchmarks/targets.py +++ b/thunder/benchmarks/targets.py @@ -39,7 +39,7 @@ thunder_sdpa_torch_compile_nvfuser_executor, ) -from thunder.tests.lit_gpt_model import Config as LitGPTConfig +from thunder.tests.litgpt_model import Config as LitGPTConfig APEX_FUSED_ROPE_AVAILABLE: bool = package_available("fused_rotary_positional_embedding") diff --git a/examples/lit-gpt/test_parametrized.py b/thunder/benchmarks/test_benchmark_litgpt.py similarity index 66% rename from examples/lit-gpt/test_parametrized.py rename to thunder/benchmarks/test_benchmark_litgpt.py index 5e658b6447..cb76e48221 100644 --- a/examples/lit-gpt/test_parametrized.py +++ b/thunder/benchmarks/test_benchmark_litgpt.py @@ -1,4 +1,4 @@ -''' +""" Script to run all lit-GPT models available as a parametrized test using abseil's unittest framework. Runs a parametrized product over all configs specified, compiler options, distributed modes etc. Uses environment variables to modify default behavior @@ -8,7 +8,7 @@ between each test. BENCHMARK_OUT_FORMAT - use this env variable to control the format in which the results are presented. Uses 'xlsx' by default. Supported: 'none', 'print', 'xlsx'. 
-''' +""" import torch from absl.testing import parameterized @@ -20,21 +20,18 @@ import pandas as pd from datetime import datetime + class Runner: - ''' + """ Benchmark Runner class to a) Launch the training benchmarking run, b) Store results from all tests, c) Compile results as xlsx file - ''' - - def __init__(self, - benchmark_file, - mid_benchmark_out, - output_format): + """ + def __init__(self, benchmark_file, mid_benchmark_out, output_format): self.dataframe_data = [] - self.json_file_path = '/tmp/benchmark_litgpt_data.json' + self.json_file_path = "/tmp/benchmark_litgpt_data.json" self.benchmark_file = benchmark_file self.mid_benchmark_out = mid_benchmark_out self.output_format = output_format @@ -44,40 +41,64 @@ def __enter__(self): def add_to_dataframe(self): if self.perf_metrics_dict: - if 'tokens_per_sec_per_gpu' not in self.perf_metrics_dict.keys(): #In case of OutofMemory error, this is already marked 'OOM' - self.perf_metrics_dict['tokens_per_sec_per_gpu'] = self.perf_metrics_dict['tokens_per_sec'] / self.perf_metrics_dict['Num GPUS'] + if ( + "tokens_per_sec_per_gpu" not in self.perf_metrics_dict.keys() + ): # In case of OutofMemory error, this is already marked 'OOM' + self.perf_metrics_dict["tokens_per_sec_per_gpu"] = ( + self.perf_metrics_dict["tokens_per_sec"] / self.perf_metrics_dict["Num GPUS"] + ) self.dataframe_data.append(self.perf_metrics_dict) def complete_dataframe(self, is_teardown): if not self.dataframe_data: # The benchmark probably failed return - #Called when tearing down the parametrized test - #This generates a summarized dataframe for each perf metric and saves as a xlsx file + # Called when tearing down the parametrized test + # This generates a summarized dataframe for each perf metric and saves as a xlsx file df = pd.DataFrame(self.dataframe_data) - df['Sharding Size'] = df['Sharding Size'].fillna('none') #Convert None Type to string so that pivot table can group. 
- index_list = ['model_name', 'Num GPUS', 'Seq Len', 'Micro BS', 'Global BS', 'GA', 'Distributed Mode', 'Sharding Size'] + df["Sharding Size"] = df["Sharding Size"].fillna( + "none" + ) # Convert None Type to string so that pivot table can group. + index_list = [ + "model_name", + "Num GPUS", + "Seq Len", + "Micro BS", + "Global BS", + "GA", + "Distributed Mode", + "Sharding Size", + ] - self.iter_time_df = df.pivot_table(index=index_list, columns='compiler', values='average_iter_time', aggfunc='first').reset_index() - self.tokens_per_sec_df = df.pivot_table(index=index_list, columns='compiler', values='tokens_per_sec', aggfunc='first').reset_index() - self.tokens_per_sec_per_gpu_df = df.pivot_table(index=index_list, columns='compiler', values='tokens_per_sec_per_gpu', aggfunc='first').reset_index() - self.memory_used_GB_df = df.pivot_table(index=index_list, columns='compiler', values='memory_used_GB', aggfunc='first').reset_index() + self.iter_time_df = df.pivot_table( + index=index_list, columns="compiler", values="average_iter_time", aggfunc="first" + ).reset_index() + self.tokens_per_sec_df = df.pivot_table( + index=index_list, columns="compiler", values="tokens_per_sec", aggfunc="first" + ).reset_index() + self.tokens_per_sec_per_gpu_df = df.pivot_table( + index=index_list, columns="compiler", values="tokens_per_sec_per_gpu", aggfunc="first" + ).reset_index() + self.memory_used_GB_df = df.pivot_table( + index=index_list, columns="compiler", values="memory_used_GB", aggfunc="first" + ).reset_index() if self.output_format == "xlsx": - output_ext = {'xlsx': '.xlsx', }[self.output_format] + output_ext = { + "xlsx": ".xlsx", + }[self.output_format] if not is_teardown: - filename = 'examples/lit-gpt/mid_output_parameterized_results' + str(output_ext) + filename = "mid_output_parameterized_results" + str(output_ext) else: - current_time = datetime.now().strftime('%Y-%m-%d_%H-%M') + current_time = datetime.now().strftime("%Y-%m-%d_%H-%M") filename = 
f"{current_time}_litgpt_benchmark" + str(output_ext) - filename = 'examples/lit-gpt/' + str(filename) - - with pd.ExcelWriter(filename, engine='xlsxwriter') as writer: - self.iter_time_df.to_excel(writer, sheet_name='Average Iter Time (ms)') - self.tokens_per_sec_df.to_excel(writer, sheet_name='Tokens per sec') - self.tokens_per_sec_per_gpu_df.to_excel(writer, sheet_name='Tokens per sec per GPU') - self.memory_used_GB_df.to_excel(writer, sheet_name='Memory allocated GB') - elif self.output_format == 'print': + + with pd.ExcelWriter(filename, engine="xlsxwriter") as writer: + self.iter_time_df.to_excel(writer, sheet_name="Average Iter Time (ms)") + self.tokens_per_sec_df.to_excel(writer, sheet_name="Tokens per sec") + self.tokens_per_sec_per_gpu_df.to_excel(writer, sheet_name="Tokens per sec per GPU") + self.memory_used_GB_df.to_excel(writer, sheet_name="Memory allocated GB") + elif self.output_format == "print": print("\nAVERAGE ITERATION TIME (ms)") print(self.iter_time_df) print("\nTHROUGHPUT (tokens/s)") @@ -91,12 +112,24 @@ def run_benchmark(self, kwargs): command_list = [] for key, val in kwargs.items(): command_list.append("--" + str(key) + "=" + str(val)) - if kwargs['distributed_mode'] != 'none': + if kwargs["distributed_mode"] != "none": nproc_per_node = torch.cuda.device_count() - subprocess_cmd = ["torchrun", f"--nproc_per_node={nproc_per_node}", "--nnodes=1", "{}".format(self.benchmark_file), "--return_metrics_as_json=True", "--json_path={}".format(self.json_file_path)] + subprocess_cmd = [ + "torchrun", + f"--nproc_per_node={nproc_per_node}", + "--nnodes=1", + f"{self.benchmark_file}", + "--return_metrics_as_json=True", + f"--json_path={self.json_file_path}", + ] subprocess_cmd.extend(command_list) else: - subprocess_cmd = ["python", "{}".format(self.benchmark_file), "--return_metrics_as_json=True", "--json_path={}".format(self.json_file_path)] + subprocess_cmd = [ + "python", + f"{self.benchmark_file}", + "--return_metrics_as_json=True", + 
f"--json_path={self.json_file_path}", + ] subprocess_cmd.extend(command_list) print(f'Running {" ".join(subprocess_cmd)!r}') @@ -104,13 +137,13 @@ def run_benchmark(self, kwargs): self.perf_metrics_dict = {} if os.path.exists(self.json_file_path): - with open(self.json_file_path, 'r') as file: + with open(self.json_file_path) as file: self.perf_metrics_dict = json.load(file) # Cleanup after the benchmark finishes. It might have failed before creating this os.remove(self.json_file_path) if proc_output.returncode: - if 'CUDA out of memory' in proc_output.stdout or "CUDA error: out of memory" in proc_output.stderr: + if "CUDA out of memory" in proc_output.stdout or "CUDA error: out of memory" in proc_output.stderr: defaultdict_oom = defaultdict(lambda: "OOM") defaultdict_oom.update(self.perf_metrics_dict) self.perf_metrics_dict = defaultdict_oom @@ -124,26 +157,28 @@ def run_benchmark(self, kwargs): class Test(parameterized.TestCase): - @classmethod def setUpClass(cls): - super(Test, cls).setUpClass() + super().setUpClass() def get_installed_thunder_path(): import thunder + thunder_init = thunder.__file__ - thunder_benchmark_file = str(thunder_init).replace('__init__.py', 'benchmarks/benchmark_litgpt.py') + thunder_benchmark_file = str(thunder_init).replace("__init__.py", "benchmarks/benchmark_litgpt.py") return thunder_benchmark_file benchmark_file = os.getenv("BENCHMARK_FILE", get_installed_thunder_path()) mid_benchmark_out = bool(os.getenv("MID_BENCHMARK_OUT", 0)) - output_format = str(os.getenv("BENCHMARK_OUT_FORMAT", "xlsx")) # Can take none, print, xlsx as of 03/12 - cls.runner = Runner(benchmark_file=benchmark_file, mid_benchmark_out=mid_benchmark_out, output_format=output_format) + output_format = str(os.getenv("BENCHMARK_OUT_FORMAT", "xlsx")) # Can take none, print, xlsx as of 03/12 + cls.runner = Runner( + benchmark_file=benchmark_file, mid_benchmark_out=mid_benchmark_out, output_format=output_format + ) @classmethod def tearDownClass(cls): 
cls.runner.complete_dataframe(is_teardown=True) - super(Test, cls).tearDownClass() + super().tearDownClass() # @parameterized.product( # (dict(distributed_mode = "fsdp", shard_mode = "zero2"), @@ -184,16 +219,23 @@ def tearDownClass(cls): # ) @parameterized.product( - distributed_mode = ("fsdp", ), - shard_mode = ("zero2", ), - model_name = ("Llama-2-7b-hf", ), - micro_batch_size = (1, 4, ), - compile = ("eager", "inductor", "thunder", "thunder_inductor",) + distributed_mode=("fsdp",), + shard_mode=("zero2",), + model_name=("Llama-2-7b-hf",), + micro_batch_size=( + 1, + 4, + ), + compile=( + "eager", + "inductor", + "thunder", + "thunder_inductor", + ), ) - def test(self, **kwargs): - kwargs['nsys_enabled'] = False - kwargs['dynamic'] = False + kwargs["nsys_enabled"] = False + kwargs["dynamic"] = False self.__file__ = __file__ try: @@ -210,5 +252,6 @@ def test(self, **kwargs): else: self.fail(run_msg) -if __name__ == '__main__': + +if __name__ == "__main__": absltest.main() diff --git a/thunder/tests/lit_gpt_model.py b/thunder/tests/litgpt_model.py similarity index 82% rename from thunder/tests/lit_gpt_model.py rename to thunder/tests/litgpt_model.py index 57a85089bc..23b51545a0 100644 --- a/thunder/tests/lit_gpt_model.py +++ b/thunder/tests/litgpt_model.py @@ -1,4 +1,4 @@ -"""Taken from https://github.com/Lightning-AI/lit-gpt/blob/main/lit_gpt/model.py""" +"""Taken from https://github.com/Lightning-AI/litgpt/blob/main/litgpt/model.py""" import torch import torch.nn as nn @@ -18,9 +18,9 @@ rotary_percentage=1.0, parallel_residual=False, bias=False, - _norm_class="RMSNorm", + norm_class_name="RMSNorm", norm_eps=1e-6, - _mlp_class="LLaMAMLP", + mlp_class_name="LLaMAMLP", intermediate_size=1376, ), dict( @@ -34,8 +34,8 @@ rotary_percentage=1.0, parallel_residual=False, bias=False, - _norm_class="RMSNorm", - _mlp_class="LLaMAMLP", + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", intermediate_size=11008, rope_condense_ratio=4, ), @@ -49,8 +49,8 @@ 
rotary_percentage=1.0, parallel_residual=False, bias=False, - _norm_class="RMSNorm", - _mlp_class="LLaMAMLP", + norm_class_name="RMSNorm", + mlp_class_name="LLaMAMLP", intermediate_size=1376, ), dict( @@ -87,9 +87,9 @@ rotary_percentage=1.0, parallel_residual=False, bias=False, - _norm_class="RMSNorm", + norm_class_name="RMSNorm", norm_eps=1e-05, - _mlp_class="LLaMAMLP", + mlp_class_name="LLaMAMLP", intermediate_size=1376, rope_base=1000000, ), @@ -104,9 +104,9 @@ n_query_groups=8, parallel_residual=False, bias=False, - _norm_class="RMSNorm", + norm_class_name="RMSNorm", norm_eps=1e-05, - _mlp_class="LLaMAMoE", + mlp_class_name="LLaMAMoE", intermediate_size=224, rope_base=1000000, n_expert=8, @@ -150,21 +150,20 @@ def reset_parameters(self) -> None: torch.nn.init.zeros_(self.v) -import lit_gpt -import lit_gpt.rmsnorm +import litgpt # override for operator workarounds -lit_gpt.model.KVCache = OverridenKVCache +litgpt.model.KVCache = OverridenKVCache # add the testing configurations -lit_gpt.config.name_to_config.update(name_to_config) -name_to_config.update(lit_gpt.config.name_to_config) +litgpt.config.name_to_config.update(name_to_config) +name_to_config.update(litgpt.config.name_to_config) # manually expose for backwards compatibility -Config = lit_gpt.Config -GPT = lit_gpt.GPT -RMSNorm = lit_gpt.rmsnorm.RMSNorm -CausalSelfAttention = lit_gpt.model.CausalSelfAttention -LLaMAMLP = lit_gpt.model.LLaMAMLP -build_rope_cache = lit_gpt.model.build_rope_cache -apply_rope = lit_gpt.model.apply_rope -Block = lit_gpt.model.Block +Config = litgpt.Config +GPT = litgpt.GPT +RMSNorm = litgpt.model.RMSNorm +CausalSelfAttention = litgpt.model.CausalSelfAttention +LLaMAMLP = litgpt.model.LLaMAMLP +build_rope_cache = litgpt.model.build_rope_cache +apply_rope = litgpt.model.apply_rope +Block = litgpt.model.Block diff --git a/thunder/tests/test_interpreter.py b/thunder/tests/test_interpreter.py index 447647b8bb..3f0dfb0a94 100644 --- a/thunder/tests/test_interpreter.py +++ 
b/thunder/tests/test_interpreter.py @@ -1337,7 +1337,7 @@ def foo(a): def foo(): # test relative import - from .lit_gpt_model import Config + from .litgpt_model import Config return Config @@ -1345,18 +1345,18 @@ def foo(): def foo(): # test relative import - from . import lit_gpt_model + from . import litgpt_model - return lit_gpt_model.Config + return litgpt_model.Config assert jit(foo)() is foo() # reload is implemented using exec of the module - from . import lit_gpt_model + from . import litgpt_model import importlib - importlib.reload(lit_gpt_model) - assert hasattr(lit_gpt_model, "GPT") + importlib.reload(litgpt_model) + assert hasattr(litgpt_model, "GPT") def test_locals_lookaside(jit): @@ -3071,7 +3071,7 @@ def test_nanogpt(jit): def test_litgpt(jit): from thunder.benchmarks import LitGPTBenchmark - from thunder.tests.lit_gpt_model import Config + from thunder.tests.litgpt_model import Config cfg: Config = Config.from_name("gpt-neox-like") bench = LitGPTBenchmark(config=cfg, device="cpu", dtype=torch.bfloat16, requires_grad=True) diff --git a/thunder/tests/test_jit_general.py b/thunder/tests/test_jit_general.py index 28fc8a54fc..3671bd12e7 100644 --- a/thunder/tests/test_jit_general.py +++ b/thunder/tests/test_jit_general.py @@ -16,7 +16,7 @@ import thunder from thunder.core.interpreter import is_jitting, InterpreterError -from thunder.tests import lit_gpt_model +from thunder.tests import litgpt_model import thunder.clang as clang from thunder.core.options import INTERPRETATION_OPTIONS, CACHE_OPTIONS import thunder.torch as ltorch @@ -505,7 +505,7 @@ def h(d, c): def test_litgpt(): from thunder.benchmarks import LitGPTBenchmark - from thunder.tests.lit_gpt_model import Config + from thunder.tests.litgpt_model import Config cfg: Config = Config.from_name("gpt-neox-like") bench = LitGPTBenchmark(config=cfg, device="cpu", dtype=torch.bfloat16, requires_grad=True) @@ -627,16 +627,16 @@ def test_litgpt_variants(name, device): device = torch.device(device) x = 
torch.randint(0, 200, (5, 5), device=device) - config = lit_gpt_model.Config.from_name(name) + config = litgpt_model.Config.from_name(name) with device: - reference = lit_gpt_model.GPT(config) + reference = litgpt_model.GPT(config) expected_logits = reference(x) expected_logits.sum().backward() with device: - model = lit_gpt_model.GPT(config) + model = litgpt_model.GPT(config) model.load_state_dict(reference.state_dict()) tom = thunder.jit(model, executors=nvfuserex if device.type == "cuda" else torchex) actual_logits = tom(x) @@ -677,10 +677,10 @@ def test_litgpt_variants_kvcache(name, device): device = torch.device(device) x = torch.randint(0, 200, (1, 2), device=device) - config = lit_gpt_model.Config.from_name(name) + config = litgpt_model.Config.from_name(name) with device: - model = lit_gpt_model.GPT(config) + model = litgpt_model.GPT(config) model.max_seq_length = 3 for p in model.parameters():