diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 896d8f927e..7039d38cf5 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -76,7 +76,7 @@ jobs: name: 'PaddlePaddle' runs-on: ubuntu-latest container: - image: nvcr.io/nvidia/paddlepaddle:24.07-py3 + image: nvcr.io/nvidia/paddlepaddle:24.10-py3 options: --user root steps: - name: 'Checkout' diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index b4eeefa70b..4762cccee6 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -17,8 +17,8 @@ jobs: uses: actions/checkout@v3 - name: 'Install dependencies' run: | - pip install sphinx==5.1.1 sphinx_rtd_theme==1.0.0 nbsphinx==0.8.10 IPython ipython_genutils==0.2.0 ipywidgets==8.0.2 astroid==2.15.7 - pip install breathe==4.34.0 sphinx-autoapi==2.0.1 + pip install sphinx==8.1.3 sphinx_rtd_theme==3.0.1 nbsphinx==0.9.5 IPython ipython_genutils==0.2.0 ipywidgets==8.0.2 astroid==3.3.2 + pip install breathe==4.35.0 sphinx-autoapi==3.3.2 sudo apt-get install -y pandoc graphviz doxygen export GIT_SHA=$(git show-ref --hash HEAD) - name: 'Build docs' diff --git a/.github/workflows/trigger-ci.yml b/.github/workflows/trigger-ci.yml index c7a85029e6..c2317c6509 100644 --- a/.github/workflows/trigger-ci.yml +++ b/.github/workflows/trigger-ci.yml @@ -39,6 +39,7 @@ jobs: || github.actor == 'pggPL' || github.actor == 'vasunvidia' || github.actor == 'erhoo82' + || github.actor == 'kocchop' ) steps: - name: Check if comment is issued by authorized person diff --git a/.gitignore b/.gitignore index 6890911c14..9b61454e21 100644 --- a/.gitignore +++ b/.gitignore @@ -22,9 +22,7 @@ __pycache__ .hypothesis .devcontainer.json tests/cpp/build/ -docs/_build .ipynb_checkpoints -docs/doxygen *.log CMakeFiles/CMakeSystem.cmake sdist/ @@ -40,3 +38,4 @@ dist/ downloads/ .pytest_cache/ compile_commands.json +.nfs diff --git a/build_tools/VERSION.txt b/build_tools/VERSION.txt index 0eed1a29ef..feaae22bac 100644 --- a/build_tools/VERSION.txt +++ b/build_tools/VERSION.txt @@ -1 +1 @@ -1.12.0 +1.13.0 diff --git a/build_tools/paddle.py b/build_tools/paddle.py index f410682875..a68d73956e 100644 --- a/build_tools/paddle.py +++ b/build_tools/paddle.py @@ -25,7 +25,7 @@ def setup_paddle_extension( # Source files csrc_source_files = Path(csrc_source_files) sources = [ - csrc_source_files / "extensions.cu", + csrc_source_files / "extensions.cpp", csrc_source_files / "common.cpp", csrc_source_files / "custom_ops.cu", ] diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index 9152229d2f..575b7bee79 100644 --- a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -11,7 +11,6 @@ from .utils import ( all_files_in_dir, cuda_archs, - cuda_path, cuda_version, ) @@ -27,11 +26,8 @@ def setup_pytorch_extension( csrc_source_files = Path(csrc_source_files) extensions_dir = csrc_source_files / "extensions" sources = [ - csrc_source_files / "common.cu", + csrc_source_files / "common.cpp", csrc_source_files / "ts_fp8_op.cpp", - csrc_source_files / "userbuffers" / "ipcsocket.cc", - csrc_source_files / "userbuffers" / "userbuffers.cu", - csrc_source_files / "userbuffers" / "userbuffers-host.cpp", ] + all_files_in_dir(extensions_dir) # Header files @@ -85,19 +81,14 @@ def setup_pytorch_extension( continue # Already handled nvcc_flags.extend(["-gencode", f"arch=compute_{arch},code=sm_{arch}"]) - # Libraries - library_dirs = [] - libraries = [] - if bool(int(os.getenv("NVTE_UB_WITH_MPI", 0))): + if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))): assert ( 
os.getenv("MPI_HOME") is not None - ), "MPI_HOME must be set when compiling with NVTE_UB_WITH_MPI=1" - mpi_home = Path(os.getenv("MPI_HOME")) - include_dirs.append(mpi_home / "include") + ), "MPI_HOME=/path/to/mpi must be set when compiling with NVTE_UB_WITH_MPI=1!" + mpi_path = Path(os.getenv("MPI_HOME")) + include_dirs.append(mpi_path / "include") cxx_flags.append("-DNVTE_UB_WITH_MPI") nvcc_flags.append("-DNVTE_UB_WITH_MPI") - library_dirs.append(mpi_home / "lib") - libraries.append("mpi") # Construct PyTorch CUDA extension sources = [str(path) for path in sources] @@ -112,6 +103,4 @@ def setup_pytorch_extension( "cxx": cxx_flags, "nvcc": nvcc_flags, }, - libraries=[str(lib) for lib in libraries], - library_dirs=[str(lib_dir) for lib_dir in library_dirs], ) diff --git a/docs/.gitignore b/docs/.gitignore new file mode 100644 index 0000000000..409af2d74e --- /dev/null +++ b/docs/.gitignore @@ -0,0 +1,3 @@ +_build +doxygen +sphinx_rtd_theme \ No newline at end of file diff --git a/docs/Makefile b/docs/Makefile index d4bb2cbb9e..800eeea78a 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -16,5 +16,10 @@ help: # Catch-all target: route all unknown targets to Sphinx using the new # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). -%: Makefile - @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) +%: Makefile sphinx_rtd_theme + PYTHONPATH=sphinx_rtd_theme:$(PYTHONPATH) $(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +# Patch Sphinx RTD theme 3.0.1 to add version selector in sidebar +sphinx_rtd_theme: + git clone --depth=1 -b 3.0.1 --single-branch https://github.com/readthedocs/sphinx_rtd_theme.git + bash -c "cd sphinx_rtd_theme; git apply ../version_select.patch" diff --git a/docs/api/pytorch.rst b/docs/api/pytorch.rst index b097f14475..ba4e7db352 100644 --- a/docs/api/pytorch.rst +++ b/docs/api/pytorch.rst @@ -51,3 +51,7 @@ pyTorch .. autoapifunction:: transformer_engine.pytorch.moe_permute .. autoapifunction:: transformer_engine.pytorch.moe_unpermute + +.. autoapifunction:: transformer_engine.pytorch.initialize_ub + +.. autoapifunction:: transformer_engine.pytorch.destroy_ub diff --git a/docs/conf.py b/docs/conf.py index 7a50ce76cf..7d2d4ea7b9 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -2,38 +2,30 @@ # # See LICENSE for license information. +import datetime import os -import sys -import sphinx_rtd_theme -from sphinx.ext.autodoc.mock import mock -from sphinx.ext.autodoc import between, ClassDocumenter, AttributeDocumenter -from sphinx.util import inspect -from builtins import str -from enum import Enum -import re +import pathlib import subprocess -from pathlib import Path -from datetime import date - -te_path = os.path.dirname(os.path.realpath(__file__)) +from builtins import str -with open(te_path + "/../build_tools/VERSION.txt", "r") as f: - te_version = f.readline().strip() +# Basic project info +project = "Transformer Engine" +author = "NVIDIA CORPORATION & AFFILIATES" +# Copyright statement release_year = 2022 - -current_year = date.today().year +current_year = datetime.date.today().year if current_year == release_year: copyright_year = release_year else: copyright_year = str(release_year) + "-" + str(current_year) +copyright = f"{copyright_year}, NVIDIA CORPORATION & AFFILIATES. All rights reserved." -project = "Transformer Engine" -copyright = "{}, NVIDIA CORPORATION & AFFILIATES. 
All rights reserved.".format(copyright_year) -author = "NVIDIA CORPORATION & AFFILIATES" +# Transformer Engine root directory +root_path = pathlib.Path(__file__).resolve().parent.parent +# Git hash git_sha = os.getenv("GIT_SHA") - if not git_sha: try: git_sha = ( @@ -44,31 +36,16 @@ ) except: git_sha = "0000000" - git_sha = git_sha[:7] if len(git_sha) > 7 else git_sha -if "dev" in te_version: - version = str(te_version + "-" + git_sha) +# Version +with open(root_path / "build_tools" / "VERSION.txt", "r") as f: + _raw_version = f.readline().strip() +if "dev" in _raw_version: + version = str(_raw_version + "-" + git_sha) else: - version = str(te_version) -release = te_version - -# hack: version is used for html creation, so put the version picker -# link here as well: -option_on = " selected" -option_off = "" -release_opt = option_on -option_nr = 0 -version = ( - version - + """
-Version select: """.format( - option_nr, release_opt - ) -) + version = str(_raw_version) +release = _raw_version # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration @@ -92,12 +69,10 @@ pygments_style = "sphinx" - # -- Options for HTML output ------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output html_theme = "sphinx_rtd_theme" -html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] html_static_path = ["_static"] html_show_sphinx = False @@ -106,7 +81,12 @@ "css/nvidia_footer.css", ] -html_theme_options = {"display_version": True, "collapse_navigation": False, "logo_only": False} +html_theme_options = { + "collapse_navigation": False, + "logo_only": False, + "version_selector": False, + "language_selector": False, +} napoleon_custom_sections = [ ("Parallelism parameters", "params_style"), @@ -116,8 +96,8 @@ ("FP8-related parameters", "params_style"), ] -breathe_projects = {"TransformerEngine": os.path.abspath("doxygen/xml/")} +breathe_projects = {"TransformerEngine": root_path / "docs" / "doxygen" / "xml"} breathe_default_project = "TransformerEngine" autoapi_generate_api_docs = False -autoapi_dirs = ["../transformer_engine"] +autoapi_dirs = [root_path / "transformer_engine"] diff --git a/docs/version_select.patch b/docs/version_select.patch new file mode 100644 index 0000000000..75f29fff81 --- /dev/null +++ b/docs/version_select.patch @@ -0,0 +1,21 @@ +diff --git a/sphinx_rtd_theme/layout.html b/sphinx_rtd_theme/layout.html +index e6a38b1..579eaec 100644 +--- a/sphinx_rtd_theme/layout.html ++++ b/sphinx_rtd_theme/layout.html +@@ -124,6 +124,16 @@ + {%- endif %} + + ++ {# Show TE version and version selector #} ++
++ {{ version }} ++
++ Version select: ++
++ + {%- if READTHEDOCS or DEBUG %} + {%- if theme_version_selector or theme_language_selector %} +
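For context on the Userbuffers build switch touched in build_tools/pytorch.py above (and mirrored in setup.py later in this diff): a minimal sketch of how the two environment variables are consumed at build time. The variable names and the error message are taken from this diff; the helper function and the pip invocation in the trailing comment are illustrative assumptions, not part of the change.

# Sketch only: NVTE_UB_WITH_MPI and MPI_HOME as used by the build scripts in this diff.
import os
from pathlib import Path

def userbuffers_mpi_config():
    """Return (enabled, mpi_include_dir) for MPI-enabled Userbuffers builds."""
    enabled = bool(int(os.getenv("NVTE_UB_WITH_MPI", "0")))
    if not enabled:
        return False, None
    mpi_home = os.getenv("MPI_HOME")
    assert (
        mpi_home is not None
    ), "MPI_HOME=/path/to/mpi must be set when compiling with NVTE_UB_WITH_MPI=1!"
    return True, Path(mpi_home) / "include"

# Illustrative build invocation (assumed):
#   NVTE_UB_WITH_MPI=1 MPI_HOME=/usr/local/mpi pip install .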
diff --git a/qa/L0_jax_unittest/test.sh b/qa/L0_jax_unittest/test.sh index 9efec6f2e5..db3aa31951 100644 --- a/qa/L0_jax_unittest/test.sh +++ b/qa/L0_jax_unittest/test.sh @@ -18,5 +18,7 @@ pip install -r $TE_PATH/examples/jax/encoder/requirements.txt pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/mnist +# Make encoder tests to have run-to-run deterministic to have the stable CI results +export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder --ignore=$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py diff --git a/qa/L1_jax_distributed_unittest/test.sh b/qa/L1_jax_distributed_unittest/test.sh index 31206898b6..eb09df1a84 100644 --- a/qa/L1_jax_distributed_unittest/test.sh +++ b/qa/L1_jax_distributed_unittest/test.sh @@ -5,4 +5,11 @@ set -xe : ${TE_PATH:=/opt/transformerengine} -pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_distributed_* + +# Skip ring attention tests since they need fixed environment vars +pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_distributed_* -k 'not test_context_parallel_ring_attn' + +# Test ring attention with and without scan loop +NVTE_FUSED_RING_ATTENTION_USE_SCAN=0 pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_distributed_fused_attn.py -k test_context_parallel_ring_attn +NVTE_FUSED_RING_ATTENTION_USE_SCAN=1 XLA_FLAGS="--xla_experimental_ignore_channel_id" \ + pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_distributed_fused_attn.py -k test_context_parallel_ring_attn diff --git a/qa/L1_pytorch_distributed_unittest/test.sh b/qa/L1_pytorch_distributed_unittest/test.sh index c22ba221be..9a11ccc008 100644 --- a/qa/L1_pytorch_distributed_unittest/test.sh +++ b/qa/L1_pytorch_distributed_unittest/test.sh @@ -10,4 +10,5 @@ pip install pytest==8.2.1 pytest -v -s $TE_PATH/tests/pytorch/distributed/test_numerics.py pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py +pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py diff --git a/qa/L1_pytorch_mcore_integration/test.sh b/qa/L1_pytorch_mcore_integration/test.sh new file mode 100644 index 0000000000..01c9e14eb1 --- /dev/null +++ b/qa/L1_pytorch_mcore_integration/test.sh @@ -0,0 +1,58 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +set -e + +# Paths +: ${TE_PATH:=/opt/transformerengine} +: ${MCORE_PATH:=${TE_PATH}/qa/L1_pytorch_mcore_integration/Megatron-LM} + +# Download Megatron-LM if needed +if [ ! 
-d "${MCORE_PATH}" ]; then + pushd $(dirname ${MCORE_PATH}) + git clone -b core_r0.9.0 https://github.com/NVIDIA/Megatron-LM.git Megatron-LM + popd +fi + +# Megatron-LM invocation +COMMAND=" +NVTE_TORCH_COMPILE=0 +NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 +NVTE_FLASH_ATTN=1 +NVTE_FWD_LAYERNORM_SM_MARGIN=0 +NVTE_BWD_LAYERNORM_SM_MARGIN=0 +CUDA_DEVICE_MAX_CONNECTIONS=1 +NVTE_BIAS_GELU_NVFUSION=0 +NVTE_BIAS_DROPOUT_FUSION=0 + +python +-m torch.distributed.launch +--use_env +--nnodes=1 +--nproc_per_node=1 + +${MCORE_PATH}/pretrain_gpt.py +--tensor-model-parallel-size 1 +--pipeline-model-parallel-size 1 +--use-cpu-initialization +--num-layers 2 +--hidden-size 128 +--num-attention-heads 8 +--seq-length 128 +--max-position-embeddings 2048 +--micro-batch-size 1 +--global-batch-size 8 +--train-iters 10 +--eval-iters 10 +--lr 1e-4 +--mock-data +--vocab-file /data/gpt3/pile-cc1-cc2-shuf/bpe/gpt2-vocab.json +--merge-file /data/gpt3/pile-cc1-cc2-shuf/bpe/gpt2-merges.txt +--transformer-impl transformer_engine +--fp8-format hybrid +" +COMMAND=$(echo "${COMMAND}" | tr '\n' ' ') + +# Launch Megatron-LM +bash -c "${COMMAND}" diff --git a/qa/L3_pytorch_FA_versions_test/test.sh b/qa/L3_pytorch_FA_versions_test/test.sh index 162ed85823..6c23e39a48 100644 --- a/qa/L3_pytorch_FA_versions_test/test.sh +++ b/qa/L3_pytorch_FA_versions_test/test.sh @@ -12,7 +12,7 @@ pip install pytest==8.2.1 export MAX_JOBS=4 # Iterate over Flash Attention versions -FA_versions=(2.1.1 2.3.0 2.4.0.post1 2.4.1 2.5.7 2.6.3 3.0.0b1) +FA_versions=(2.1.1 2.3.0 2.4.1 2.5.7 2.6.3 3.0.0b1) for fa_version in "${FA_versions[@]}" do diff --git a/setup.py b/setup.py index 512defa619..3bb2fe6b95 100644 --- a/setup.py +++ b/setup.py @@ -57,13 +57,20 @@ def run(self): def setup_common_extension() -> CMakeExtension: """Setup CMake extension for common library""" + cmake_flags = ["-DCMAKE_CUDA_ARCHITECTURES={}".format(cuda_archs())] + if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))): + assert ( + os.getenv("MPI_HOME") is not None + ), "MPI_HOME must be set when compiling with NVTE_UB_WITH_MPI=1" + cmake_flags.append("-DNVTE_UB_WITH_MPI=ON") + # Project directory root root_path = Path(__file__).resolve().parent return CMakeExtension( name="transformer_engine", cmake_path=root_path / Path("transformer_engine/common"), - cmake_flags=["-DCMAKE_CUDA_ARCHITECTURES={}".format(cuda_archs())], + cmake_flags=cmake_flags, ) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 6991d83d4c..20b16c2809 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -3,9 +3,8 @@ # See LICENSE for license information. 
from contextlib import nullcontext -import functools -import operator from typing import Callable, List, Sequence, Union +import os import jax import jax.numpy as jnp @@ -14,12 +13,18 @@ from jax import jit, value_and_grad from flax import linen as nn -from utils import assert_allclose +from utils import assert_allclose, assert_tree_like_allclose from transformer_engine.jax.dot import type_safe_dot_general, dequantize, quantize from transformer_engine.jax.fp8 import FP8MetaPackage, FP8Helper, is_fp8_available from transformer_engine.jax.layernorm import layernorm, layernorm_fp8_dot from transformer_engine.jax.layernorm_mlp import activation_lu, fused_layernorm_fp8_mlp from transformer_engine.jax.cpp_extensions.activation import _jax_act_lu +from transformer_engine.jax.cpp_extensions.transpose import ( + _jax_transpose, + _jax_cast_transpose, + _jax_dbias_cast_transpose, +) +from transformer_engine.jax.cpp_extensions.quantization import _jax_cast_fp8 from transformer_engine.jax import cpp_extensions as tex @@ -500,7 +505,6 @@ def _prim_func_bwd(ctx, g): scale_inv, FP8Helper.BWD_DTYPE, -1, - -2, self.activation_type, ) ) @@ -746,3 +750,130 @@ def ref_func(x, y, gamma, beta, zero_centered_gamma): assert_allclose(primitive_gamma_grad, ref_gamma_grad, dtype=FP8Helper.BWD_DTYPE) if beta is not None: assert_allclose(primitive_beta_grad, ref_beta_grad, dtype=FP8Helper.BWD_DTYPE) + + +@pytest.mark.parametrize( + "in_dtype", + [ + pytest.param(jnp.float32, id="input_float32"), + pytest.param(jnp.float16, id="input_float16"), + pytest.param(jnp.bfloat16, id="input_bfloat16"), + ], +) +@pytest.mark.parametrize( + "input_shape, transpose_axis", + [ + pytest.param((16, 16), 1, id="(16, 16)-1"), + pytest.param((256, 128), 1, id="(256, 128)-1"), + pytest.param((128, 512), 1, id="(128, 512)-1"), + pytest.param((64, 16, 4, 256), 1, id="(64, 16, 4, 256)-1"), + pytest.param((64, 16, 4, 256), 2, id="(64, 16, 4, 256)-2"), + pytest.param((64, 16, 4, 256), 3, id="(64, 16, 4, 256)-3"), + ], +) +class TestTranspose: + def test_transpose(self, in_dtype, input_shape, transpose_axis): + key = jax.random.PRNGKey(0) + input_tensor = jax.random.uniform(key, input_shape, in_dtype) + static_axis_boundary = -1 + jax_output = _jax_transpose(input_tensor, static_axis_boundary, transpose_axis) + os.environ["NVTE_JAX_WITH_FFI"] = "0" + noffi_output = tex.transpose(input_tensor, static_axis_boundary, transpose_axis) + os.environ["NVTE_JAX_WITH_FFI"] = "1" + ffi_output = tex.transpose(input_tensor, static_axis_boundary, transpose_axis) + assert_allclose(jax_output, noffi_output) + assert_allclose(noffi_output, ffi_output) + + @pytest.mark.parametrize( + "out_dtype", + [ + pytest.param(jnp.float8_e4m3fn, id="output_float8_e4m3fn"), + pytest.param(jnp.float8_e5m2, id="output_float8_e5m2"), + ], + ) + def test_cast_transpose(self, in_dtype, input_shape, transpose_axis, out_dtype): + amax = jnp.zeros(1, jnp.float32) + scale = jnp.ones(1, jnp.float32) + scale_inv = jnp.ones(1, jnp.float32) + key = jax.random.PRNGKey(0) + input = jax.random.uniform(key, input_shape, in_dtype) + static_axis_boundary = -1 + jax_output = _jax_cast_transpose( + input, scale, amax, out_dtype, static_axis_boundary, transpose_axis + ) + os.environ["NVTE_JAX_WITH_FFI"] = "0" + noffi_output = tex.cast_transpose( + input, amax, scale, scale_inv, out_dtype, static_axis_boundary, transpose_axis + ) + os.environ["NVTE_JAX_WITH_FFI"] = "1" + ffi_output = tex.cast_transpose( + input, amax, scale, scale_inv, out_dtype, static_axis_boundary, transpose_axis + ) + 
assert_tree_like_allclose(jax_output, ffi_output) + assert_tree_like_allclose(noffi_output, ffi_output) + + @pytest.mark.parametrize( + "out_dtype", + [ + pytest.param(jnp.float8_e4m3fn, id="output_float8_e4m3fn"), + pytest.param(jnp.float8_e5m2, id="output_float8_e5m2"), + ], + ) + def test_dbias_cast_transpose(self, in_dtype, input_shape, transpose_axis, out_dtype): + amax = jnp.zeros(1, jnp.float32) + scale = jnp.ones(1, jnp.float32) + scale_inv = jnp.ones(1, jnp.float32) + key = jax.random.PRNGKey(0) + input = jax.random.uniform(key, input_shape, in_dtype) + static_axis_boundary = -1 + jax_output = _jax_dbias_cast_transpose( + input, amax, scale, out_dtype, static_axis_boundary, transpose_axis + ) + os.environ["NVTE_JAX_WITH_FFI"] = "0" + noffi_output = tex.dbias_cast_transpose( + input, amax, scale, scale_inv, out_dtype, static_axis_boundary, transpose_axis + ) + os.environ["NVTE_JAX_WITH_FFI"] = "1" + ffi_output = tex.dbias_cast_transpose( + input, amax, scale, scale_inv, out_dtype, static_axis_boundary, transpose_axis + ) + assert_tree_like_allclose(jax_output, ffi_output) + assert_tree_like_allclose(noffi_output, ffi_output) + + +@pytest.mark.skipif(not is_fp8_supported, reason=reason) +@pytest.mark.parametrize( + "input_shape", + [ + pytest.param((256, 128), id="(256, 128)"), + pytest.param((128, 512, 8), id="(128, 512, 8)"), + ], +) +@pytest.mark.parametrize( + "in_dtype", + [ + pytest.param(jnp.float32, id="input_float32"), + pytest.param(jnp.float16, id="input_float16"), + pytest.param(jnp.bfloat16, id="input_bfloat16"), + ], +) +@pytest.mark.parametrize( + "out_dtype", + [ + pytest.param(jnp.float8_e4m3fn, id="output_float8_e4m3fn"), + pytest.param(jnp.float8_e5m2, id="output_float8_e5m2"), + ], +) +def test_quantize(input_shape, in_dtype, out_dtype): + amax = jnp.zeros(1, jnp.float32) + scale = jnp.ones(1, jnp.float32) + scale_inv = jnp.ones(1, jnp.float32) + key = jax.random.PRNGKey(0) + input = jax.random.uniform(key, input_shape, in_dtype) + jax_output = _jax_cast_fp8(input, scale, amax, out_dtype) + os.environ["NVTE_JAX_WITH_FFI"] = "0" + noffi_output = tex.cast_fp8(input, amax, scale, scale_inv, out_dtype) + os.environ["NVTE_JAX_WITH_FFI"] = "1" + ffi_output = tex.cast_fp8(input, amax, scale, scale_inv, out_dtype) + assert_tree_like_allclose(jax_output, ffi_output) + assert_tree_like_allclose(noffi_output, ffi_output) diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index 23a26087d4..7ef0d68474 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -35,7 +35,9 @@ get_qkv_format, reorder_causal_load_balancing, inverse_reorder_causal_load_balancing, + CPStrategy, ) +from transformer_engine.jax.sharding import MeshResource # We will use the golden reference model from our non distributed attention test fixture. 
from test_fused_attn import general_dot_product_attention, make_mask @@ -133,7 +135,6 @@ def test_self_attn( seqlen, hidden, None, # no window - False, # not context parallel ): pytest.skip(f"No FusedAttn backend found") @@ -268,7 +269,6 @@ def test_cross_attn( seqlen, hidden, None, # no window - False, # not context parallel ): pytest.skip(f"No FusedAttn backend found") @@ -335,6 +335,36 @@ def ref_func(query, kv, mask): ) +@pytest.mark.parametrize( + "device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs() +) +@pytest.mark.parametrize( + "data_shape", + [ + pytest.param([2, 512, 12, 128], id="2-512-12-128"), + pytest.param([4, 1024, 16, 64], id="4-1024-16-64"), + ], +) +@pytest.mark.parametrize("kv_groups", [1, 4, 8, 12, 16]) +@pytest.mark.parametrize( + "attn_mask_type", + [ + pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL_MASK"), + pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"), + ], +) +@pytest.mark.parametrize("dtype", [jnp.bfloat16]) +@pytest.mark.parametrize( + "qkv_layout", + [ + pytest.param(QKVLayout.BSHD_BS2HD, id="COMBINED_KV"), + pytest.param(QKVLayout.BSHD_BSHD_BSHD, id="SEPARATE"), + ], +) +@pytest.mark.parametrize( + "load_balanced", + [pytest.param(False, id="UNBALANCED"), pytest.param(True, id="BALANCED")], +) class TestDistributedContextParallelSelfAttn: def generate_inputs(self, shape, kv_groups: int, attn_mask_type: AttnMaskType, dtype): @@ -372,37 +402,7 @@ def qkv_to_layout(self, q, k, v, qkv_layout): raise ValueError(f"Unsupported {qkv_layout=}") return qkv_args - @pytest.mark.parametrize( - "device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs() - ) - @pytest.mark.parametrize( - "data_shape", - [ - pytest.param([2, 512, 12, 128], id="2-512-12-128"), - pytest.param([4, 1024, 16, 64], id="4-1024-16-64"), - ], - ) - @pytest.mark.parametrize("kv_groups", [1, 4, 8, 12, 16]) - @pytest.mark.parametrize( - "attn_mask_type", - [ - pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL_MASK"), - pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"), - ], - ) - @pytest.mark.parametrize("dtype", [jnp.bfloat16]) - @pytest.mark.parametrize( - "qkv_layout", - [ - pytest.param(QKVLayout.BSHD_BS2HD, id="COMBINED_KV"), - pytest.param(QKVLayout.BSHD_BSHD_BSHD, id="SEPARATE"), - ], - ) - @pytest.mark.parametrize( - "load_balanced", - [pytest.param(False, id="UNBALANCED"), pytest.param(True, id="BALANCED")], - ) - def test_contex_parallel_self_attn( + def impl_test_contex_parallel_attn( self, device_count, mesh_shape, @@ -414,6 +414,7 @@ def test_contex_parallel_self_attn( dtype, qkv_layout, load_balanced, + cp_strategy, ): attn_bias_type = AttnBiasType.NO_BIAS dropout_prob = 0.0 @@ -425,22 +426,32 @@ def test_contex_parallel_self_attn( num_kv_heads = num_head // kv_groups scaling_factor = 1.0 / np.sqrt(num_head) - if not is_fused_attn_kernel_available( - dtype, - dtype, - qkv_layout, - attn_bias_type, - attn_mask_type, - dropout_prob, - num_head, - num_kv_heads, - seqlen, - seqlen, - hidden, - None, # no window - cp_size > 1, - ): - pytest.skip(f"No FusedAttn backend found") + def check_has_backend_for_mask(mask_type): + return is_fused_attn_kernel_available( + dtype, + dtype, + qkv_layout, + attn_bias_type, + attn_mask_type, + dropout_prob, + num_head, + num_kv_heads, + seqlen, + seqlen, + hidden, + None, + ) # no SWA for CP + + # For causal masking we depend on having bottom right support also. 
+ # The API does not check this and instead we rely on lower level checks to raise + # and exception if the step backend is not supported. This was a deliberate API + # decision to keep the CP size or flag out of the function. + has_backend = check_has_backend_for_mask(attn_mask_type) + if cp_size > 1 and attn_mask_type == AttnMaskType.CAUSAL_MASK: + has_backend &= check_has_backend_for_mask(AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK) + + if not has_backend: + pytest.skip(f"No FusedAttn backend found {cp_size=} {attn_mask_type=}.") if dp_size > 1 and batch % dp_size != 0: pytest.skip(f"Skipping {batch=} not a multiple of {dp_size=}") @@ -461,6 +472,7 @@ def target_func(q, k, v, mask): scaling_factor=scaling_factor, dropout_probability=dropout_prob, is_training=is_training, + context_parallel_strategy=cp_strategy, context_parallel_causal_load_balanced=load_balanced, context_parallel_axis="cp", ).astype(dtype) @@ -566,6 +578,60 @@ def grad_func(func, *args, **kwargs): assert_allclose(target_grads[i], ref_grads[i], dtype=dtype) + def test_contex_parallel_allgather_attn( + self, + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + data_shape, + kv_groups, + attn_mask_type, + dtype, + qkv_layout, + load_balanced, + ): + return self.impl_test_contex_parallel_attn( + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + data_shape, + kv_groups, + attn_mask_type, + dtype, + qkv_layout, + load_balanced, + CPStrategy.ALL_GATHER, + ) + + def test_context_parallel_ring_attn( + self, + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + data_shape, + kv_groups, + attn_mask_type, + dtype, + qkv_layout, + load_balanced, + ): + return self.impl_test_contex_parallel_attn( + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + data_shape, + kv_groups, + attn_mask_type, + dtype, + qkv_layout, + load_balanced, + CPStrategy.RING, + ) + class TestReorderCausalLoadBalancing: @pytest.mark.parametrize("cp_size", [2, 4, 8]) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index d4f92e940d..af05538ef5 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -7,6 +7,7 @@ from functools import partial from math import sqrt from typing import Tuple, Optional +import random import jax import jax.numpy as jnp @@ -305,11 +306,14 @@ def _check_configs(self): ]: pytest.skip("THD format requires padding masks.") - if self.qkv_layout == QKVLayout.BS3HD or get_qkv_format(self.qkv_layout) == QKVFormat.THD: - if self.num_heads_q != self.num_heads_kv: - pytest.skip("QKVPACKED layout requires num_heads_q and num_heads_kv to be equal.") + qkv_format = get_qkv_format(self.qkv_layout) + if self.qkv_layout == QKVLayout.BS3HD or qkv_format == QKVFormat.THD: if self.max_seqlen_q != self.max_seqlen_kv: - pytest.skip("QKVPACKED layout requires max_seqlen_q and max_seqlen_kv to be equal.") + pytest.skip(f"{self.qkv_layout} requires max_seqlen_q == max_seqlen_kv") + + if self.qkv_layout == QKVLayout.BS3HD or self.qkv_layout == QKVLayout.T3HD: + if self.num_heads_q != self.num_heads_kv: + pytest.skip(f"{self.qkv_layout} requires num_heads_q == num_heads_kv") if self.max_seqlen_q > self.max_seqlen_kv and self.window_size is not None: pytest.skip( diff --git a/tests/jax/test_layer.py b/tests/jax/test_layer.py index 3245bca676..55c09b4562 100644 --- a/tests/jax/test_layer.py +++ b/tests/jax/test_layer.py @@ -67,7 +67,7 @@ def enable_fused_attn(): _KEY_OF_TRANSPOSE_BS: True, _KEY_OF_NUM_HEADS: 8, _KEY_OF_HIDDEN_DROPOUT: 0, - _KEY_OF_ATTENTION_DROPOUT: 0, + 
_KEY_OF_ATTENTION_DROPOUT: 0.0, _KEY_OF_INTERMEDIATE_DROPOUT: 0, _KEY_OF_SELF_ATTN_MASK_TYPE: "padding_causal", _KEY_OF_LAYERNORM_TYPE: "layernorm", diff --git a/tests/jax/test_misc.py b/tests/jax/test_misc.py new file mode 100644 index 0000000000..67145daf63 --- /dev/null +++ b/tests/jax/test_misc.py @@ -0,0 +1,40 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import pytest +from functools import partial +import os + +from transformer_engine.jax.cpp_extensions.misc import get_xla_flag + + +@pytest.fixture(autouse=True, scope="function") +def preserve_xla_flags(): + """Ensures the XLA flags environment variable is restored after any tests in this file run.""" + old_flags = os.getenv("XLA_FLAGS") + yield + if old_flags is not None: + os.environ["XLA_FLAGS"] = old_flags + + +def test_get_xla_flag(request): + os.environ["XLA_FLAGS"] = "" + assert get_xla_flag("") is None + assert get_xla_flag("--foo") is None + assert get_xla_flag("--bar=1") is None + + os.environ["XLA_FLAGS"] = "--foo --bar=1 --baz=biz" + assert get_xla_flag("--foo") == True + assert get_xla_flag("--bar") == "1" + assert get_xla_flag("--bar", cast=int) == 1 + assert get_xla_flag("--bar", cast=bool) == True + assert get_xla_flag("--baz") == "biz" + with pytest.raises(ValueError): + # cast will fail + assert get_xla_flag("--baz", cast=int) + assert get_xla_flag("--xla") is None + + os.environ["XLA_FLAGS"] = "--xla_abc --xla_abb" + assert get_xla_flag("--xla_abc") == True + assert get_xla_flag("--xla_abb") == True diff --git a/tests/pytorch/distributed/run_gemm_with_overlap.py b/tests/pytorch/distributed/run_gemm_with_overlap.py index 5ba70ccbdd..b00b8cc042 100644 --- a/tests/pytorch/distributed/run_gemm_with_overlap.py +++ b/tests/pytorch/distributed/run_gemm_with_overlap.py @@ -22,6 +22,7 @@ from transformer_engine.common.recipe import Format from transformer_engine.pytorch.fp8 import _default_sf_compute +warnings.filterwarnings("ignore", category=DeprecationWarning) warnings.filterwarnings("ignore", category=FutureWarning) warnings.filterwarnings("ignore", category=UserWarning) @@ -32,8 +33,8 @@ } nvte_comm_types = { - "rs": 0, - "ag": 1, + "rs": tex.CommOverlapType.RS, + "ag": tex.CommOverlapType.AG, } @@ -75,7 +76,7 @@ def _parse_args(argv=None, namespace=None): parser.add_argument( "--comm-type", type=partial(_mapped_argtype, typemap=nvte_comm_types), - default=0, + default=tex.CommOverlapType.AG, help="Comm type to overlap.", ) parser.add_argument( @@ -156,12 +157,10 @@ def _parse_args(argv=None, namespace=None): if opts.fp8: warnings.warn("Bulk overlap is supported in FP8 but only tested in BF16.") opts.fp8 = False - elif opts.comm_type == 1: + elif opts.comm_type == tex.CommOverlapType.AG: if opts.atomic: setattr(opts, "atomic_rs_p2p", opts.p2p) - if not opts.p2p: - warnings.warn("All-gather overlap is only supported with point-2-point comms.") - opts.p2p = True + opts.p2p = True if opts.atomic: if not te.fp8.check_fp8_support(): @@ -283,35 +282,35 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None if WORLD_RANK == 0: print("\n", end="", flush=True) - ub_callbacks = ( - tex.UbufBootstrapCallbacks() + helper = ( + tex.CommOverlapHelper() if tex.ubuf_built_with_mpi() - else tex.UbufBootstrapCallbacks(bootstrap_pg, bootstrap_pg) + else tex.CommOverlapHelper(bootstrap_pg) ) - if opts.comm_type == 0: + if opts.comm_type == tex.CommOverlapType.RS: if opts.bulk_overlap: - ub_algo = 
tex.UbufOverlapAlgo.BULK_OVERLAP_RS + ub_algo = tex.CommOverlapAlgo.BULK_OVERLAP_RS elif opts.p2p: ub_algo = ( - tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P + tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P if opts.atomic - else tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P + else tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P ) else: ub_algo = ( - tex.UbufOverlapAlgo.ATOMIC_GEMM_RS + tex.CommOverlapAlgo.ATOMIC_GEMM_RS if opts.atomic - else tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS + else tex.CommOverlapAlgo.SPLIT_PIPELINED_RS ) - elif opts.comm_type == 1: + elif opts.comm_type == tex.CommOverlapType.AG: if opts.bulk_overlap: - ub_algo = tex.UbufOverlapAlgo.BULK_OVERLAP_AG + ub_algo = tex.CommOverlapAlgo.BULK_OVERLAP_AG else: ub_algo = ( - tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P + tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P if opts.atomic - else tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P + else tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P ) else: raise TypeError("Invalid comm+GEMM overlap type!") @@ -322,95 +321,55 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None hidden_size = opts.num_heads * opts.head_dim inp_shape = (opts.seq_length, opts.batch_size, hidden_size) outer_size = reduce(operator.mul, inp_shape[:-1], 1) - ubuf_dtype = torch.bfloat16 - if opts.fp8 and not opts.bulk_overlap and (opts.comm_type == 1 or opts.fp8_output): - ubuf_dtype = torch.uint8 - sample_buffer = torch.empty((outer_size, hidden_size), dtype=ubuf_dtype, device="cuda") - ub_obj = ub_obj = ( - tex.UbufP2PCommOverlap( - sample_buffer, # Sample userbuffer - WORLD_RANK, # World rank - WORLD_SIZE, # World size - LOCAL_RANK, # Rank within the node - LOCAL_SIZE, # Number of ranks/GPUs per node - 0, # Node ID - 1, # Number of nodes + buffer_dtype = torch.bfloat16 + if ( + opts.fp8 + and not opts.bulk_overlap + and (opts.comm_type == tex.CommOverlapType.AG or opts.fp8_output) + ): + buffer_dtype = torch.uint8 + ub_obj = ( + tex.CommOverlapP2P( + (outer_size, hidden_size), + buffer_dtype, + helper, tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE) - 1, # Number of communication SMs - 1, # CGA cluster size - opts.comm_type == 0 or opts.atomic, # Set SM margin - opts.aggregate, # Aggregate 2X GEMM chunks - 3, # Max concurrent GEMM streams - opts.comm_type == 0, # overlap with reduce scatter - opts.atomic, # use a single GEMM with atomic-counters - not (opts.atomic and bool(int(os.getenv("NVTE_AG_P2P_MULTI_ATOMIC", "0")))), - ub_callbacks, + opts.comm_type, + set_sm_margin=opts.comm_type == tex.CommOverlapType.RS or opts.atomic, + atomic_gemm=opts.atomic, + aggregate=opts.aggregate, + use_ce=not (opts.atomic and bool(int(os.getenv("NVTE_AG_P2P_MULTI_ATOMIC", "0")))), ) if opts.p2p - else tex.UbufCommOverlap( - sample_buffer, # Sample userbuffer - WORLD_RANK, # World rank - WORLD_SIZE, # World size - LOCAL_RANK, # Rank within the node - LOCAL_SIZE, # Number of ranks/GPUs per node - 0, # Node ID - 1, # Number of nodes + else tex.CommOverlap( + (outer_size, hidden_size), + buffer_dtype, + helper, tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE) - 16, # Number of communication SMs - 2, # CGA cluster size - 4, # Number of communication splits - True, # Set SM margin - 3, # Max concurrent GEMM streams - opts.atomic, # Use a single GEMM with atomic-counters - ub_callbacks, + atomic_gemm=opts.atomic, ) ) # Numerical check on AG + atomic GEMM requires testing an AG+RS pair ub_obj2 = None - if opts.atomic and opts.comm_type == 1 and opts.check_numerics: - sample_buffer2 = torch.empty( 
- (outer_size, hidden_size), - dtype=torch.uint8 if opts.fp8_output else torch.bfloat16, - device="cuda", - ) + if opts.atomic and opts.comm_type == tex.CommOverlapType.AG and opts.check_numerics: ub_obj2 = ( - tex.UbufP2PCommOverlap( - sample_buffer2, # Sample userbuffer - WORLD_RANK, # World rank - WORLD_SIZE, # World size - LOCAL_RANK, # Rank within the node - LOCAL_SIZE, # Number of ranks/GPUs per node - 0, # Node ID - 1, # Number of nodes + tex.CommOverlapP2P( + (outer_size, hidden_size), + torch.uint8 if opts.fp8_output else torch.bfloat16, + helper, tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE) - 1, # Number of communication SMs - 1, # CGA cluster size - True, # Set SM margin - False, # Aggregate 2X GEMM chunks - 3, # Max concurrent GEMM streams - True, # overlap with reduce scatter - True, # use a single GEMM with atomic-counters - True, # use copy engine for P2P communications - ub_callbacks, + tex.CommOverlapType.RS, + set_sm_margin=True, + atomic_gemm=True, ) if opts.atomic_rs_p2p - else tex.UbufCommOverlap( - sample_buffer2, # Sample userbuffer - WORLD_RANK, # World rank - WORLD_SIZE, # World size - LOCAL_RANK, # Rank within the node - LOCAL_SIZE, # Number of ranks/GPUs per node - 0, # Node ID - 1, # Number of nodes + else tex.CommOverlap( + (outer_size, hidden_size), + torch.uint8 if opts.fp8_output else torch.bfloat16, + helper, tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE) - 16, # Number of communication SMs - 2, # CGA cluster size - 4, # Number of communication splits - True, # Set SM margin - 3, # Max concurrent GEMM streams - True, # uUe a single GEMM with atomic-counters - ub_callbacks, + atomic_gemm=True, ) ) @@ -426,12 +385,12 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None local_kernel_t_shape = (ffn_hidden_size, hidden_size) local_inp_shape = (outer_size, hidden_size) # Bulk overlap comm tensor is distributed for AG overlap only - if opts.comm_type == 1: + if opts.comm_type == tex.CommOverlapType.AG: bulk_inp_shape = (outer_size // tp_size, hidden_size) else: bulk_inp_shape = (outer_size, hidden_size) else: - if opts.comm_type == 1: + if opts.comm_type == tex.CommOverlapType.AG: # (M/P, N) -> overlapped AG -> (M, N) x (K/P, N)^T = (M, K/P) local_kernel_t_shape = (ffn_hidden_size // tp_size, hidden_size) local_inp_shape = (outer_size // tp_size, hidden_size) @@ -472,7 +431,7 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None std=opts.std, ) else: - if opts.comm_type == 1: + if opts.comm_type == tex.CommOverlapType.AG: # AG Kernel: (K/P, N) -> gather -> (K, N) -> T -> (N, K) ker_g = torch.transpose( te.distributed.gather_along_first_dim(kernel_t, tp_group)[0], 0, 1 @@ -494,7 +453,7 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None ).to(dtype=torch.float32) if opts.bulk_overlap: - if opts.comm_type == 1: + if opts.comm_type == tex.CommOverlapType.AG: ref_g = te.distributed.gather_along_first_dim(bulk_inp, tp_group)[0] else: # First all-gather all the bulk inputs into a list @@ -505,7 +464,7 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None else: ref_g = torch.matmul(inp_g, ker_g) if ub_obj2 is not None: - inp2_g = torch.nn.functional.gelu(ref_g) + inp2_g = torch.nn.functional.gelu(ref_g) # pylint: disable=not-callable ref2_g = torch.matmul(inp2_g, ker2_g) if opts.fp8: @@ -529,7 +488,7 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None 
fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM1_WEIGHT].copy_(ker_amax) ref_amax = torch.max(torch.abs(ref_g)) fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM1_OUTPUT].copy_(ref_amax) - if opts.bulk_overlap and opts.comm_type == 0: + if opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.RS: bulk_amax = torch.max(torch.abs(bulk_inp)) fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM2_OUTPUT].copy_(bulk_amax) elif ub_obj2 is not None: @@ -551,7 +510,7 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None kernel_t_fp8 = tex.cast_to_fp8( kernel_t, fp8_meta, tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype ) - if opts.bulk_overlap and opts.comm_type == 0: + if opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.RS: bulk_inp_fp8 = tex.cast_to_fp8( bulk_inp, fp8_meta, tex.FP8Tensors.GEMM2_OUTPUT, fp8_dtype ) @@ -574,7 +533,7 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None rtol=0.125, atol=0.0675, ) - if opts.bulk_overlap and opts.comm_type == 0: + if opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.RS: torch.allclose( bulk_inp.to(dtype=torch.float32), bulk_inp_fp8 * fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM2_OUTPUT], @@ -590,7 +549,7 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None ) # Set Fp8 scales for userbuffers - if opts.comm_type == 1: + if opts.comm_type == tex.CommOverlapType.AG: ub_obj.set_ubuf_scale_inv(fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM1_INPUT]) if ub_obj2 is not None: ub_obj2.set_ubuf_scale_inv(fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM2_OUTPUT]) @@ -602,7 +561,7 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None # Set up comm/compute buffers ubuf_out2 = None rs_out2 = None - if opts.comm_type == 1: + if opts.comm_type == tex.CommOverlapType.AG: if opts.bulk_overlap: ub_obj.copy_input_to_ubuf(bulk_inp, 1) gemm_inp = inp @@ -686,9 +645,9 @@ def _fp8_gemm2(gemm1_out): gelu=False, use_split_accumulator=te.module.base._2X_ACC_FPROP, ub_algo=( - tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P + tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P if opts.atomic_rs_p2p - else tex.UbufOverlapAlgo.ATOMIC_GEMM_RS + else tex.CommOverlapAlgo.ATOMIC_GEMM_RS ), ub=ub_obj2, extra_output_tensor=rs_out2, @@ -762,10 +721,14 @@ def _gemm(): avg_gpu_time = sum(gpu_times) / opts.timing_iters gemm_name = "".join( [ - "p2p all-gather + " if opts.comm_type == 1 else "", + "p2p all-gather + " if opts.comm_type == tex.CommOverlapType.AG else "", "atomic " if opts.atomic else "", "GEMM", - (f" + {'p2p ' if opts.p2p else ''}reduce-scatter" if opts.comm_type == 0 else ""), + ( + f" + {'p2p ' if opts.p2p else ''}reduce-scatter" + if opts.comm_type == tex.CommOverlapType.RS + else "" + ), ] ) timing_info = ( @@ -781,7 +744,7 @@ def _gemm(): dist.barrier(tp_group) if opts.bulk_overlap: output_info = "" - if opts.comm_type == 1: + if opts.comm_type == tex.CommOverlapType.AG: # Bulk overlap AG output is already gathered test_out = ub_obj.get_ubuf_output(1) else: @@ -794,7 +757,7 @@ def _gemm(): output_info += f"output: {list(test_out.shape)} | reference: {list(ref_out.shape)}" dist_print( output_info, - src=0 if opts.comm_type == 0 else None, + src=0 if opts.comm_type == tex.CommOverlapType.RS else None, section=True, ) @@ -805,7 +768,7 @@ def _gemm(): ) dist_print(nonzero_info, src=0, section=True, group=tp_group) else: - if opts.comm_type == 1: + if opts.comm_type == tex.CommOverlapType.AG: if ub_obj2 is not None: # AG+RS Output: (M/P, N) -> gather -> (M, N) output = 
rs_out2.to(dtype=torch.float32) diff --git a/tests/pytorch/distributed/run_layer_with_overlap.py b/tests/pytorch/distributed/run_layer_with_overlap.py index e5653bda01..e32a7ccb12 100644 --- a/tests/pytorch/distributed/run_layer_with_overlap.py +++ b/tests/pytorch/distributed/run_layer_with_overlap.py @@ -9,7 +9,6 @@ import socket import argparse import warnings -from functools import partial import torch import torch.distributed as dist diff --git a/tests/pytorch/distributed/test_comm_gemm_overlap.py b/tests/pytorch/distributed/test_comm_gemm_overlap.py index 63310195ae..ce46a72189 100644 --- a/tests/pytorch/distributed/test_comm_gemm_overlap.py +++ b/tests/pytorch/distributed/test_comm_gemm_overlap.py @@ -42,6 +42,9 @@ # Force GPU kernels to launch in the order they're executed by the host CPU os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" +# Clear torch.dynamo caches +torch._dynamo.reset() + def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, fp8_in, fp8_out, aggregate): test_path = TEST_ROOT / "run_gemm_with_overlap.py" diff --git a/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py b/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py new file mode 100644 index 0000000000..ead121f314 --- /dev/null +++ b/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py @@ -0,0 +1,495 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +from __future__ import annotations + +import argparse +import dataclasses +import functools +import itertools +import os +import pathlib +import subprocess +import sys + +import pytest +import torch + +import transformer_engine +import transformer_engine.pytorch as te +import transformer_engine.pytorch.cpp_extensions as tex +from transformer_engine.pytorch.float8_tensor import Float8Tensor +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +import transformer_engine.pytorch.ops as te_ops +from transformer_engine.pytorch.ops._common import is_float8_tensor +from transformer_engine.pytorch.ops.fused import ( + UserbuffersBackwardLinear, + UserbuffersForwardLinear, +) +from transformer_engine.pytorch.utils import is_bf16_compatible + +# Import utility functions +_current_file = pathlib.Path(__file__).resolve() +sys.path.append(str(_current_file.parent.parent)) +from utils import dtype_tols, str_to_dtype + +# Check if FP8 is supported +fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() + +# Check if there are multiple GPUs +if torch.cuda.device_count() < 2: + pytest.skip("Userbuffers requires at least 2 GPUs.") + + +@dataclasses.dataclass +class ModelConfig: + """Tensor dimensions in Transformer model""" + + sequence_length: int + batch_size: int + num_heads: int + head_dim: int + dtype: torch.dtype + fp8: bool + + @property + def hidden_size(self): + return self.num_heads * self.head_dim + + +@functools.cache +def launcher() -> str: + """Launcher for current parallel job""" + if "OMPI_COMM_WORLD_SIZE" in os.environ: + return "ompi" + if "TORCHELASTIC_RUN_ID" in os.environ: + return "torchrun" + raise RuntimeError(f"{__file__} must be launched with either `mpirun` or `torchrun`") + + +@functools.cache +def world_group() -> torch.distributed.ProcessGroup: + """Get NCCL process group, initializing if needed""" + + # Get launch config from environment + if launcher() == "ompi": + # OpenMPI + world_size = int(os.getenv("OMPI_COMM_WORLD_SIZE")) + rank = int(os.getenv("OMPI_COMM_WORLD_RANK")) + local_size = 
int(os.getenv("OMPI_COMM_WORLD_LOCAL_SIZE")) + local_rank = int(os.getenv("OMPI_COMM_WORLD_LOCAL_RANK")) + elif launcher() == "torchrun": + # torchrun + world_size = int(os.getenv("WORLD_SIZE")) + rank = int(os.getenv("RANK")) + local_size = int(os.getenv("LOCAL_WORLD_SIZE")) + local_rank = int(os.getenv("LOCAL_RANK")) + else: + raise RuntimeError("Unexpected launcher ({launcher()})") + + # Construct communicator + assert local_size == world_size + torch.cuda.set_device(local_rank) + group = torch.distributed.init_process_group( + "nccl", + init_method="file:///tmp/rdzv", + world_size=world_size, + rank=rank, + device_id=torch.device(f"cuda:{local_rank}"), + ) + return group + + +def reset_rng(seed: int = 1234) -> None: + """Reset random number generators""" + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + +@torch.no_grad() +def make_reference_and_test_tensors( + shape: int | Iterable[int], + ref_dtype: torch.dtype = torch.float64, + ref_device: torch.device = "cpu", + test_dtype: torch.dtype = torch.float32, + test_device: torch.device = "cuda", + test_is_fp8: bool = False, + requires_grad: bool = True, +) -> tuple[torch.Tensor, torch.Tensor]: + """Construct tensors with the same values + + The reference tensor is intended for use in plain PyTorch + operations in high precision. The test tensor is intended for use + in Transformer Engine operations. + + """ + + # Random data + ref = torch.rand(shape, dtype=ref_dtype, device=ref_device) + + # Make copy of tensor + if test_is_fp8: + test = Float8Tensor.to_float8(ref) + else: + test = ref.to(device=test_device, dtype=test_dtype) + if test.data_ptr() == ref.data_ptr(): + test = test.clone() + + # Make sure reference and test tensors represent exact same values + ref.copy_(test) + + # Return reference and test tensors + ref.requires_grad_(requires_grad) + test.requires_grad_(requires_grad) + return ref, test + + +def _test_linear( + *, + model_config: ModelConfig, + bias: bool = False, + device: torch.device = "cuda", + tensor_parallel_mode: str = "column", + sequence_parallel: bool = True, + weight_requires_grad: bool = True, +) -> None: + dtype = model_config.dtype + fp8_compute = model_config.fp8 + + # Distributed process group + process_group = world_group() + rank = torch.distributed.get_rank(process_group) + world_size = torch.distributed.get_world_size(process_group) + + # Tensor dimensions + out_features = model_config.hidden_size + in_features = model_config.hidden_size + batch_size = model_config.sequence_length * model_config.batch_size + in_shape = [batch_size, in_features] + out_shape = [batch_size, out_features] + + # Random data + reset_rng() + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + test_is_fp8=fp8_compute, + ) + w_ref, w_test = make_reference_and_test_tensors( + (out_features, in_features), + test_dtype=dtype, + test_device=device, + test_is_fp8=fp8_compute, + ) + b_ref, b_test = None, None + if bias: + if tensor_parallel_mode == "row": + bias_shape = [world_size, out_features] + else: + bias_shape = [out_features] + b_ref, b_test = make_reference_and_test_tensors( + bias_shape, + test_dtype=dtype, + test_device=device, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + out_shape, + test_dtype=dtype, + test_device=device, + test_is_fp8=fp8_compute, + requires_grad=False, + ) + + # Plain PyTorch implementation + y_ref = torch.nn.functional.linear(x_ref, w_ref) + if bias: + if tensor_parallel_mode == "row": + y_ref += b_ref.sum(dim=0) + else: + 
y_ref += b_ref + y_ref.backward(dy_ref) + + # Convert to distributed tensors + with torch.no_grad(): + dw_ref = w_ref.grad + db_ref = b_ref.grad if bias else None + dx_ref = x_ref.grad + if tensor_parallel_mode == "column": + local_out_features = out_features // world_size + local_slice = slice( + rank * local_out_features, + (rank + 1) * local_out_features, + ) + w_ref = w_ref[local_slice, :] + dw_ref = dw_ref[local_slice, :] + w_test = w_test[local_slice, :] + if bias: + b_ref = b_ref[local_slice] + db_ref = db_ref[local_slice] + b_test = b_test[local_slice] + y_ref = y_ref[..., local_slice] + dy_ref = dy_ref[..., local_slice] + dy_test = dy_test[..., local_slice].clone() + elif tensor_parallel_mode == "row": + local_in_features = in_features // world_size + local_slice = slice( + rank * local_in_features, + (rank + 1) * local_in_features, + ) + w_ref = w_ref[:, local_slice] + dw_ref = dw_ref[:, local_slice] + w_test = w_test[:, local_slice] + if bias: + b_ref = b_ref[rank, :] + db_ref = db_ref[rank, :] + b_test = b_test[rank, :] + x_ref = x_ref[..., local_slice] + dx_ref = dx_ref[..., local_slice] + x_test = x_test[..., local_slice].clone() + if sequence_parallel: + local_batch_size = batch_size // world_size + local_slice = slice( + rank * local_batch_size, + (rank + 1) * local_batch_size, + ) + if tensor_parallel_mode == "column": + x_ref = x_ref[local_slice, ...] + dx_ref = dx_ref[local_slice, ...] + x_test = x_test[local_slice, ...].clone() + elif tensor_parallel_mode == "row": + y_ref = y_ref[local_slice, ...] + dy_ref = dy_ref[local_slice, ...] + dy_test = dy_test[local_slice, ...].clone() + x_test.requires_grad_() + + # Implementation with fusible operation + with te.fp8_model_init(enabled=fp8_compute): + ops = [] + linear_op = None + bias_op = None + if tensor_parallel_mode == "column": + userbuffers_options = {} + if not weight_requires_grad: + if fp8_compute: + userbuffers_options["comm_name"] = "fc1" + else: + # There is a correctness bug with overlapping + # dgrad reduce-scatter with dgrad GEMM. Fall back + # to overlapping dgrad reduce-scatter with wgrad + # GEMM, even though wgrad isn't needed. 
+ userbuffers_options["comm_name"] = "qkv" + else: + userbuffers_options["comm_name"] = "qkv" + linear_op = te_ops.BasicLinear( + in_features, + out_features, + device=device, + dtype=dtype, + tensor_parallel_mode=tensor_parallel_mode, + tensor_parallel_group=process_group, + sequence_parallel=sequence_parallel, + userbuffers_options=userbuffers_options, + ) + ops.append(linear_op) + if bias: + bias_op = te_ops.Bias( + out_features // world_size, + device=device, + dtype=dtype, + ) + ops.append(bias_op) + elif tensor_parallel_mode == "row": + userbuffers_options = dict(comm_name="proj") + linear_op = te_ops.BasicLinear( + in_features // world_size, + out_features, + device=device, + dtype=dtype, + userbuffers_options=userbuffers_options, + ) + ops.append(linear_op) + if bias: + bias_op = te_ops.Bias(out_features, device=device, dtype=dtype) + ops.append(bias_op) + ops.append(te_ops.ReduceScatter(process_group)) + model = te_ops.Sequential(*ops) + with torch.no_grad(): + linear_op.weight.copy_(w_test) + linear_op.weight.requires_grad_(requires_grad=weight_requires_grad) + if bias: + bias_op.bias.copy_(b_test) + del w_test + del b_test + with te.fp8_autocast(enabled=fp8_compute): + y_test = model(x_test) + y_test.backward(dy_test) + + # Check that forward operations have been fused + forward_ops = model._module_groups[0]._forward_ops + backward_ops = model._module_groups[0]._backward_ops + assert len(forward_ops) == 1 + assert len(backward_ops) == 1 + assert isinstance(forward_ops[0][0], UserbuffersForwardLinear) + assert isinstance(backward_ops[0][0], UserbuffersBackwardLinear) + + # Expected numerical error + tols = dtype_tols(dtype) + if dtype == torch.float32: + tols = dtype_tols(torch.float16) # TF32 GEMM + if fp8_compute: + tols = dtype_tols( + model[0].weight._fp8_dtype + if is_float8_tensor(model[0].weight) + else tex.DType.kFloat8E4M3 + ) + + # Check results + y_test = y_test.to(dtype=torch.float64, device="cpu") + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y_test, y_ref, **tols) + torch.testing.assert_close(dx_test, dx_ref, **tols) + if weight_requires_grad: + dw_test = linear_op.weight.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(dw_test, dw_ref, **tols) + if bias: + db_test = bias_op.bias.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(db_test, db_ref, **tols) + + +def run_parallel_tests(model_config: ModelConfig) -> None: + """Run parallel tests""" + + # Distributed process group + process_group = world_group() + rank = torch.distributed.get_rank(process_group) + world_size = torch.distributed.get_world_size(process_group) + + # Linear op + for test_config in itertools.product( + (False, True), # bias + ("column", "row"), # tensor_parallel_mode + (False, True), # weight_requires_grad + ): + if rank == 0: + print(f"Running _test_linear with {test_config=}") + bias, tensor_parallel_mode, weight_requires_grad = test_config + _test_linear( + model_config=model_config, + bias=bias, + tensor_parallel_mode=tensor_parallel_mode, + weight_requires_grad=weight_requires_grad, + ) + + +# Parallel job sizes +_world_sizes = [] +if torch.cuda.device_count() > 1: + _world_sizes.append(torch.cuda.device_count()) + + +@pytest.mark.parametrize("world_size", _world_sizes) +@pytest.mark.parametrize("fp8", (False, True)) +def test_fuser_ops_with_userbuffers( + *, + world_size: int, + dtype: torch.dtype = torch.bfloat16, + fp8: bool, +) -> None: + """Launch parallel job and run tests""" + + # Skip invalid 
configurations + if fp8 and not fp8_available: + pytest.skip(reason_for_no_fp8) + + # Parallel job launcher + command = [] + if tex.ubuf_built_with_mpi(): + python_exe = pathlib.Path(sys.executable).resolve() + command.extend(("mpirun", "-np", str(world_size), "--oversubscribe", "--quiet", python_exe)) + else: + command.extend(("torchrun", f"--nproc_per_node={world_size}")) + + # Script invocation + command.extend( + ( + _current_file, + "--parallel", + "--batch-size", + str(world_size), + "--num-heads", + str(world_size), + "--dtype", + str(dtype), + ) + ) + if fp8: + command.append("--fp8") + + # Environment + env = dict(os.environ) + if not tex.device_supports_multicast(): + env["UB_SKIPMC"] = "1" + env["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" + env["PYTORCH_JIT"] = "0" + env["NVTE_TORCH_COMPILE"] = "0" + env["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0" + + # Launch parallel job + result = subprocess.run(command, check=True, env=env) + + +def main() -> None: + + # Parse command-line arguments + parser = argparse.ArgumentParser() + parser.add_argument("--parallel", action="store_true", help="Run parallel tests") + parser.add_argument("--sequence-length", type=int, default=32) + parser.add_argument("--batch-size", type=int, default=16) + parser.add_argument("--num-heads", type=int, default=16) + parser.add_argument("--head-dim", type=int, default=32) + parser.add_argument("--dtype", type=str, default="bfloat16") + parser.add_argument("--fp8", action="store_true") + args = parser.parse_args() + + # Run parallel tests if needed + if args.parallel: + + # Model config + model_config = ModelConfig( + sequence_length=args.sequence_length, + batch_size=args.batch_size, + num_heads=args.num_heads, + head_dim=args.head_dim, + dtype=str_to_dtype(args.dtype), + fp8=args.fp8, + ) + + # Initialize Userbuffers + group = world_group() # Initialize NCCL + bootstrap_backend = "mpi" if launcher() == "ompi" else "nccl" + userbuffer_configs = { + "fc1_dgrad": {"method": "pipeline"}, # Overlap dgrad RS with dgrad GEMM + } + te.module.base.initialize_ub( + [ + model_config.sequence_length * model_config.batch_size, + model_config.num_heads * model_config.head_dim, + ], + torch.distributed.get_world_size(group), + use_fp8=model_config.fp8, + dtype=model_config.dtype, + bootstrap_backend=bootstrap_backend, + ub_cfgs=userbuffer_configs, + ) + + # Run tests + run_parallel_tests(model_config) + + # Clean up + te.module.base.destroy_ub() + + +if __name__ == "__main__": + main() diff --git a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py b/tests/pytorch/fused_attn/run_fused_attn_with_cp.py index 15fb994050..2d863b3bba 100644 --- a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py +++ b/tests/pytorch/fused_attn/run_fused_attn_with_cp.py @@ -11,16 +11,24 @@ import transformer_engine_torch as tex from test_fused_attn_with_cp import model_configs_flash_attn, model_configs_fused_attn from transformer_engine.pytorch.fp8 import fp8_autocast +from transformer_engine.pytorch.float8_tensor import Float8Tensor from transformer_engine.common.recipe import DelayedScaling dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16} def run_dpa_with_cp( - dtype="bf16", model=None, qkv_format="bshd", kernel_backend="FlashAttention", cp_comm_type="p2p" + dtype="bf16", + model=None, + qkv_format="bshd", + kernel_backend="FlashAttention", + cp_comm_type="p2p", + fp8_mha=False, ): """Test DotProductAttention module with context parallelism""" + # args are passed as strings + fp8_mha = fp8_mha == "True" 
os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "0" if kernel_backend == "FlashAttention": @@ -72,7 +80,7 @@ def run_dpa_with_cp( cp_comm_sub_groups.append(sub_group) if dtype == "fp8": - fp8_recipe = DelayedScaling(fp8_dpa=True) + fp8_recipe = DelayedScaling(fp8_dpa=True, fp8_mha=fp8_mha) # instantiate core attn module core_attn = DotProductAttention( @@ -201,7 +209,11 @@ def run_dpa_with_cp( None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded[:-1] ), ) - out.backward(dout) + if fp8_mha: + dout_fp8 = Float8Tensor.to_float8(dout, fp8_dtype=tex.DType.kFloat8E5M2) + out.backward(dout_fp8) + else: + out.backward(dout) # run core_attn wit CP q_, k_, v_, dout_, *rest = [ @@ -269,7 +281,11 @@ def run_dpa_with_cp( None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded[:-1] ), ) - out_.backward(dout_) + if fp8_mha: + dout_fp8_ = Float8Tensor.to_float8(dout_, fp8_dtype=tex.DType.kFloat8E5M2) + out_.backward(dout_fp8_) + else: + out_.backward(dout_) for x in [out_, q_.grad, k_.grad, v_.grad]: assert torch.all(~torch.isnan(x)) diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py index 4b4eecbf39..4e995dabb1 100644 --- a/tests/pytorch/fused_attn/test_fused_attn.py +++ b/tests/pytorch/fused_attn/test_fused_attn.py @@ -619,14 +619,14 @@ def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout): qkv_layouts_thd = ["t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"] model_configs_layout_thd = { # test: b, h, hg, d, sq, skv, p, mask, bias - "layout_0_1": ModelConfig(1, 16, 16, 64, 128, 128, 0.0, "padding", "no_bias"), - "layout_0_2": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding", "no_bias"), + "layout_0_1": ModelConfig(3, 16, 4, 64, 128, 128, 0.0, "padding", "no_bias"), + "layout_0_2": ModelConfig(8, 16, 4, 64, 128, 128, 0.0, "padding", "no_bias"), "layout_0_3": ModelConfig(1, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"), "layout_0_4": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"), "layout_1_1": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"), "layout_1_2": ModelConfig(8, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"), - "layout_1_3": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"), - "layout_1_4": ModelConfig(8, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"), + "layout_1_3": ModelConfig(1, 16, 1, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"), + "layout_1_4": ModelConfig(8, 16, 1, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"), "layout_2_1": ModelConfig(1, 16, 16, 128, 128, 128, 0.0, "padding", "no_bias"), "layout_2_2": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"), "layout_2_3": ModelConfig(1, 16, 16, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"), @@ -644,6 +644,9 @@ def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout): @pytest.mark.parametrize("qkv_layout", qkv_layouts_thd) def test_dpa_qkv_layout_thd(dtype, model_configs, model, qkv_layout): """Test DotProductAttention module with different QKV layouts""" + config = model_configs[model] + if config.num_heads != config.num_gqa_groups and "3" in qkv_layout: + pytest.skip("qkv_layout not applicable for MQA/GQA") pad_between_seqs = True test_dot_product_attention( dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs @@ -1353,8 +1356,6 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, config = model_configs_fp8_vs_f16[model] if _flash_attn_3_is_installed and not 
is_training: - if RoPE: - pytest.skip("Flash Attention doesn't support FP8 MHA with RoPE.") os.environ["NVTE_FLASH_ATTN"] = "1" os.environ["NVTE_FUSED_ATTN"] = "0" _attention_backends["backend_selection_requires_update"] = True diff --git a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py index ea30a4831f..1007d6aa34 100644 --- a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py +++ b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py @@ -113,7 +113,8 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): @pytest.mark.parametrize("model", model_configs_fused_attn.keys()) @pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"]) @pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a", "a2a+p2p"]) -def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type): +@pytest.mark.parametrize("fp8_mha", [False, True]) +def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha): if qkv_format == "thd" and get_device_compute_capability() < (9, 0): pytest.skip("THD format is only supported on sm90+!") if cp_comm_type == "all_gather" and get_cudnn_version() < (9, 3, 0): @@ -153,6 +154,8 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type): f"CP implementation with QKVO A2A requires num_heads ({config.num_heads}) and" f" num_gqa_groups ({config.num_gqa_groups}) to be divisible by cp_size (2)!" ) + if dtype != "fp8" and fp8_mha: + pytest.skip("Only fp8 works with fp8_mha=True!") subprocess.run( get_bash_arguments( @@ -162,6 +165,7 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type): qkv_format=qkv_format, kernel_backend="FusedAttention", cp_comm_type=cp_comm_type, + fp8_mha=fp8_mha, ), check=True, ) diff --git a/tests/pytorch/test_cuda_graphs.py b/tests/pytorch/test_cuda_graphs.py index 60a5a1ea99..010050baea 100644 --- a/tests/pytorch/test_cuda_graphs.py +++ b/tests/pytorch/test_cuda_graphs.py @@ -13,24 +13,25 @@ LayerNormLinear, LayerNormMLP, Linear, - make_graphed_callables, MultiheadAttention, TransformerLayer, fp8_autocast, fp8_model_init, + make_graphed_callables, ) from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.utils import is_bf16_compatible +import transformer_engine.pytorch.ops as te_ops -# Only run FP8 tests on H100. +# Check if FP8 is supported. fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() +# Record initial RNG state. seed = 1234 torch.manual_seed(seed) torch.cuda.manual_seed(seed) -# Record initial RNG state from script run. 
_cpu_rng_state = torch.get_rng_state() _cuda_rng_state = torch.cuda.get_rng_state() @@ -48,17 +49,14 @@ class ModelConfig: model_configs = {"small": ModelConfig(2, 32, 64, 2, 32)} -modules = ["transformer", "layernorm_mlp", "layernorm_linear", "linear", "mha", "dpa"] - -all_boolean = [True, False] - -dtypes = [torch.float32, torch.float16] +# Supported data types +dtypes: List[torch.dtype] = [torch.float32, torch.float16] if is_bf16_compatible(): # bf16 requires sm_80 or higher dtypes.append(torch.bfloat16) def reset_rng_states() -> None: - """revert back to initial RNG state.""" + """Revert to initial RNG state.""" torch.set_rng_state(_cpu_rng_state) torch.cuda.set_rng_state(_cuda_rng_state) @@ -70,64 +68,40 @@ def reset_global_fp8_state(): def assert_all_equal(l1: List[torch.Tensor], l2: List[torch.Tensor], names=None) -> bool: - """Ensures two lists are equal.""" + """Check that two lists of tensors match exactly.""" assert len(l1) == len(l2), "Unequal number of outputs." - failed = False - failed_tensors = "" + failure_message = "Output mismatches in:" + failed_tensors = [] for i, (t1, t2) in enumerate(zip(l1, l2)): if not torch.equal(t1, t2): - failed = True - failed_tensors += ( - f" {names[i]}\n" if names is not None else f" tensor at idx={i}\n" - ) - assert not failed, "Output mismatches in:\n" + failed_tensors + failure_message += "\n " + if names is None: + failure_message += f"tensor at idx={i}" + else: + failure_message += names[i] + failed_tensors.append((t1, t2)) + if failed_tensors: + print(failure_message) + t1, t2 = failed_tensors[0] + torch.testing.assert_close(t1, t2, rtol=0, atol=0) def generate_data( - config: ModelConfig, + model_config: ModelConfig, dtype: torch.dtype, - dpa: bool = False, warmup: bool = False, - return_grad_output: bool = False, -) -> Tuple[List[torch.Tensor], torch.Tensor]: + requires_grad: bool = True, +) -> torch.Tensor: """Generate synthetic data.""" gen_func = torch.ones if warmup else torch.randn - if dpa: - inputs = [ - gen_func( - config.sequence_length, - config.batch_size, - config.num_heads, - config.kv_channels, - device="cuda", - requires_grad=True, - dtype=dtype, - ) - for _ in range(3) - ] - else: - inputs = [ - gen_func( - config.sequence_length, - config.batch_size, - config.hidden_size, - device="cuda", - requires_grad=True, - dtype=dtype, - ) - ] - - if not return_grad_output: - return inputs - - grad_output = torch.randn( - config.sequence_length, - config.batch_size, - config.hidden_size, + return gen_func( + model_config.sequence_length, + model_config.batch_size, + model_config.hidden_size, device="cuda", + requires_grad=requires_grad, dtype=dtype, ) - return inputs, grad_output def get_outputs( @@ -157,30 +131,44 @@ def forward(self, input_: torch.Tensor, **kwargs) -> torch.Tensor: return x +# Supported modules +_test_cuda_graphs_modules: List[str] = [ + "transformer", + "layernorm_mlp", + "layernorm_linear", + "linear", + "mha", + "linear_op", +] + + def _test_cuda_graphs( *, - config: ModelConfig, + graph_mode: str, + module: str, + model_config: ModelConfig, num_layers: int, dtype: torch.dtype, fp8: bool, fp8_params: bool, fp8_weight_caching: bool, - module: str, - graph_mode: str, ) -> List[torch.Tensor]: """Helper function for CUDA graph test.""" reset_rng_states() FP8GlobalStateManager.reset() - dpa = module == "dpa" + # Operation-based API does not support FP8 weight caching. + if module == "linear_op": + fp8_weight_caching = False + + # Create modules. with fp8_model_init(enabled=fp8_params): - # Create modules. 
if module == "transformer": modules = [ TransformerLayer( - config.hidden_size, - config.hidden_size, - config.num_heads, + model_config.hidden_size, + model_config.hidden_size, + model_config.num_heads, hidden_dropout=0.0, attention_dropout=0.0, fuse_qkv_params=True, @@ -190,37 +178,56 @@ def _test_cuda_graphs( ] elif module == "layernorm_mlp": modules = [ - LayerNormMLP(config.hidden_size, config.hidden_size, params_dtype=dtype) + LayerNormMLP( + model_config.hidden_size, + model_config.hidden_size, + params_dtype=dtype, + ) for _ in range(num_layers) ] elif module == "layernorm_linear": modules = [ - LayerNormLinear(config.hidden_size, config.hidden_size, params_dtype=dtype) + LayerNormLinear( + model_config.hidden_size, + model_config.hidden_size, + params_dtype=dtype, + ) for _ in range(num_layers) ] elif module == "mha": modules = [ MultiheadAttention( - config.hidden_size, - config.num_heads, + model_config.hidden_size, + model_config.num_heads, attention_dropout=0.0, params_dtype=dtype, fuse_qkv_params=True, ) for _ in range(num_layers) ] - elif dpa: - assert config.hidden_size % config.num_heads == 0, "Err." - assert num_layers == 1, "Err." + elif module == "linear": modules = [ - DotProductAttention(config.num_heads, config.kv_channels, attention_dropout=0.0) + Linear( + model_config.hidden_size, + model_config.hidden_size, + device="cuda", + params_dtype=dtype, + ) for _ in range(num_layers) ] - else: + elif module == "linear_op": modules = [ - Linear(config.hidden_size, config.hidden_size, device="cuda", params_dtype=dtype) + te_ops.Sequential( + te_ops.Linear( + model_config.hidden_size, + model_config.hidden_size, + dtype=dtype, + ), + ) for _ in range(num_layers) ] + else: + raise ValueError(f"Unknown module type ({module})") # Initialize gradient buffers. for module in modules: @@ -230,111 +237,208 @@ def _test_cuda_graphs( # Generate model and wrap API to return graphed version. if graph_mode == "full": # Graph entire model at once. - model = modules[0] if dpa else torch.nn.Sequential(*modules) + model = torch.nn.Sequential(*modules) model = make_graphed_callables( model, - generate_data(config, dtype, dpa=dpa, warmup=True), + (generate_data(model_config, dtype, warmup=True),), num_warmup_iters=10, fp8_enabled=fp8, fp8_weight_caching=fp8_weight_caching, ) elif graph_mode == "individual": - # Graph individual modules + # Graph individual modules. modules = [ make_graphed_callables( module, - generate_data(config, dtype, dpa=dpa, warmup=True), + (generate_data(model_config, dtype, warmup=True),), num_warmup_iters=10, fp8_enabled=fp8, fp8_weight_caching=fp8_weight_caching, ) for module in modules ] - model = modules[0] if dpa else _Sequential(*modules) + model = _Sequential(*modules) else: - model = modules[0] if dpa else _Sequential(*modules) + model = _Sequential(*modules) # Optimizer. - if not dpa: - optimizer = torch.optim.SGD(model.parameters(), lr=0.001) + optimizer = torch.optim.SGD(model.parameters(), lr=0.001) - # Launch. + # Training steps. 
for _ in range(3): - if not dpa: - optimizer.zero_grad(set_to_none=False) + optimizer.zero_grad(set_to_none=False) for grad_accumulation_step in range(2): - inputs, grad_output = generate_data(config, dtype, dpa=dpa, return_grad_output=True) + input_ = generate_data(model_config, dtype) + grad_output = generate_data(model_config, dtype, requires_grad=False) with fp8_autocast(enabled=fp8): kwargs = {} if fp8_weight_caching: kwargs["is_first_microbatch"] = grad_accumulation_step == 0 - output = model(*inputs, **kwargs) + output = model(input_, **kwargs) output.backward(grad_output) - if not dpa: - optimizer.step() + optimizer.step() return get_outputs(model, output) +@pytest.mark.parametrize("module", _test_cuda_graphs_modules) @pytest.mark.parametrize("dtype", dtypes) -@pytest.mark.parametrize("model", model_configs.keys()) -@pytest.mark.parametrize("num_layers", [1, 3]) -@pytest.mark.parametrize("fp8", all_boolean) -@pytest.mark.parametrize("fp8_params", all_boolean) -@pytest.mark.parametrize("fp8_weight_caching", all_boolean) -@pytest.mark.parametrize("module", modules) -def test_gpt_make_graphed_callables( +@pytest.mark.parametrize("fp8", (False, True)) +@pytest.mark.parametrize("fp8_params", (False, True)) +def test_make_graphed_callables( + *, + module: str, + model_config: str = "small", + num_layers: int = 3, dtype: torch.dtype, - model: str, - num_layers: int, fp8: bool, fp8_params: bool, - fp8_weight_caching: bool, - module: str, + fp8_weight_caching: bool = False, ) -> None: + + # Skip invalid configurations. if fp8 and not fp8_available: pytest.skip(reason_for_no_fp8) if fp8_params and not fp8: pytest.skip("FP8 needed for FP8 parameters.") if fp8_weight_caching and not fp8: pytest.skip("FP8 needed for FP8 parameters.") - if module == "dpa" and num_layers > 1: - pytest.skip("Max 1 layer for DPA.") - - config = model_configs[model] + # Run model with different CUDA graph settings. + model_config = model_configs[model_config] kwargs = dict( - config=config, + module=module, + model_config=model_config, num_layers=num_layers, dtype=dtype, fp8=fp8, fp8_params=fp8_params, fp8_weight_caching=fp8_weight_caching, - module=module, ) outputs = _test_cuda_graphs(graph_mode="none", **kwargs) graph_outputs_mode1 = _test_cuda_graphs(graph_mode="full", **kwargs) graph_outputs_mode2 = _test_cuda_graphs(graph_mode="individual", **kwargs) - # Check that results match + # Check that results match. 
assert_all_equal(outputs, graph_outputs_mode1) assert_all_equal(outputs, graph_outputs_mode2) -def _test_cuda_graphs_with_kwargs( +_test_make_graphed_callables_with_fp8_weight_caching_modules = [ + "transformer", + "layernorm_mlp", + "layernorm_linear", + "linear", + "mha", +] + + +@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) +@pytest.mark.parametrize( + "module", + _test_make_graphed_callables_with_fp8_weight_caching_modules, +) +@pytest.mark.parametrize("fp8_params", (False, True)) +def test_make_graphed_callables_with_fp8_weight_caching( *, - config: ModelConfig, + module: str, + fp8_params: bool, +) -> None: + test_make_graphed_callables( + module=module, + dtype=torch.float32, + fp8=True, + fp8_params=fp8_params, + fp8_weight_caching=True, + ) + + +def generate_data_for_dot_product_attention( + model_config: ModelConfig, dtype: torch.dtype, + warmup: bool = False, +) -> List[torch.Tensor]: + """Generate synthetic data for dot product attention.""" + gen_func = torch.ones if warmup else torch.randn + return [ + gen_func( + model_config.sequence_length, + model_config.batch_size, + model_config.num_heads, + model_config.kv_channels, + device="cuda", + requires_grad=True, + dtype=dtype, + ) + for _ in range(3) + ] + + +def _test_cuda_graphs_with_dot_product_attention( + *, with_graph: bool, + model_config: ModelConfig, + dtype: torch.dtype, ) -> List[torch.Tensor]: - """Simulate Megatron-LM interleaved pipeline parallelism.""" + """Helper function for CUDA graph test.""" + reset_rng_states() + FP8GlobalStateManager.reset() + + # Create dot product attention module. + assert model_config.hidden_size % model_config.num_heads == 0 + model = DotProductAttention( + model_config.num_heads, + model_config.kv_channels, + attention_dropout=0.0, + ) + + # Graph model if needed. + if with_graph: + model = make_graphed_callables( + model, + generate_data_for_dot_product_attention(model_config, dtype, warmup=True), + num_warmup_iters=10, + fp8_enabled=False, + ) + + # Forward and backward passes. + for _ in range(3): + inputs = generate_data_for_dot_product_attention(model_config, dtype) + grad_output = generate_data(model_config, dtype, requires_grad=False) + output = model(*inputs) + output.backward(grad_output) + + return get_outputs(model, output) + + +@pytest.mark.parametrize("dtype", dtypes) +def test_make_graphed_callables_with_dot_product_attention( + *, + model_config: str = "small", + dtype: torch.dtype, +) -> None: + """Test CUDA graphs with dot product attention.""" + model_config = model_configs[model_config] + kwargs = dict(model_config=model_config, dtype=dtype) + outputs = _test_cuda_graphs_with_dot_product_attention(with_graph=False, **kwargs) + graph_outputs = _test_cuda_graphs_with_dot_product_attention(with_graph=True, **kwargs) + assert_all_equal(outputs, graph_outputs) + + +def _test_cuda_graphs_with_kwargs( + *, + with_graph: bool, + model_config: ModelConfig, + dtype: torch.dtype, +) -> List[torch.Tensor]: + """Helper function for CUDA graph test with keyword arguments.""" reset_rng_states() # Initialize model. model = TransformerLayer( - config.hidden_size, - config.hidden_size, - config.num_heads, + model_config.hidden_size, + model_config.hidden_size, + model_config.num_heads, hidden_dropout=0.0, attention_dropout=0.0, self_attn_mask_type="arbitrary", @@ -349,13 +453,18 @@ def _test_cuda_graphs_with_kwargs( # Make graphed version of model if needed. 
if with_graph: attn_mask = torch.zeros( - (config.batch_size, 1, config.sequence_length, config.sequence_length), + ( + model_config.batch_size, + 1, + model_config.sequence_length, + model_config.sequence_length, + ), dtype=torch.bool, device="cuda", ) model = make_graphed_callables( model, - generate_data(config, dtype, warmup=True), + (generate_data(model_config, dtype, warmup=True),), sample_kwargs=dict(attention_mask=attn_mask), allow_unused_input=True, ) @@ -367,14 +476,20 @@ def _test_cuda_graphs_with_kwargs( for _ in range(3): optimizer.zero_grad(set_to_none=False) for grad_accumulation_step in range(2): - inputs, grad_output = generate_data(config, dtype, return_grad_output=True) + input_ = generate_data(model_config, dtype) + grad_output = generate_data(model_config, dtype, requires_grad=False) attn_mask = torch.randint( 2, - (config.batch_size, 1, config.sequence_length, config.sequence_length), + ( + model_config.batch_size, + 1, + model_config.sequence_length, + model_config.sequence_length, + ), dtype=torch.bool, device="cuda", ) - output = model(*inputs, attention_mask=attn_mask) + output = model(input_, attention_mask=attn_mask) output.backward(grad_output) optimizer.step() @@ -382,12 +497,13 @@ def _test_cuda_graphs_with_kwargs( def test_make_graphed_callables_with_kwargs( + *, + model_config: str = "small", dtype: torch.dtype = torch.float32, - model: str = "small", ) -> None: """Test CUDA graphs with keyword arguments.""" - config = model_configs[model] - kwargs = dict(config=config, dtype=dtype) + model_config = model_configs[model_config] + kwargs = dict(model_config=model_config, dtype=dtype) outputs = _test_cuda_graphs_with_kwargs(with_graph=False, **kwargs) graph_outputs = _test_cuda_graphs_with_kwargs(with_graph=True, **kwargs) assert_all_equal(outputs, graph_outputs) @@ -395,9 +511,9 @@ def test_make_graphed_callables_with_kwargs( def _test_cuda_graphs_with_interleaved_pipeline_parallelism( *, - config: ModelConfig, - dtype: torch.dtype, with_graph: bool, + model_config: ModelConfig, + dtype: torch.dtype, ) -> List[torch.Tensor]: """Simulate Megatron-LM interleaved pipeline parallelism.""" reset_rng_states() @@ -411,8 +527,8 @@ def _test_cuda_graphs_with_interleaved_pipeline_parallelism( model = torch.nn.ModuleList( [ Linear( - config.hidden_size, - config.hidden_size, + model_config.hidden_size, + model_config.hidden_size, params_dtype=dtype, ) for _ in range(num_layers) @@ -430,7 +546,8 @@ def _test_cuda_graphs_with_interleaved_pipeline_parallelism( } if with_graph: sample_args = tuple( - generate_data(config, dtype, warmup=True) for _ in range(num_layers * num_microbatches) + (generate_data(model_config, dtype, warmup=True),) + for _ in range(num_layers * num_microbatches) ) layer_forwards = make_graphed_callables( tuple(model), @@ -455,9 +572,10 @@ def _test_cuda_graphs_with_interleaved_pipeline_parallelism( grad_outputs = {} for layer_idx in range(num_layers): for microbatch_idx in range(num_microbatches): - x, dy = generate_data(config, dtype, return_grad_output=True) + x = generate_data(model_config, dtype) + dy = generate_data(model_config, dtype, requires_grad=False) idxs = (layer_idx, microbatch_idx) - inputs[idxs] = x[0] + inputs[idxs] = x grad_outputs[idxs] = dy # Cache for layer outputs. 
@@ -494,12 +612,13 @@ def backward(layer_idx: int, microbatch_idx: int): def test_make_graphed_callables_with_interleaved_pipeline_parallelism( + *, + model_config: str = "small", dtype: torch.dtype = torch.float16, - model: str = "small", ) -> None: """Test CUDA graphs with Megatron-LM interleaved pipeline parallelism.""" - config = model_configs[model] - kwargs = dict(config=config, dtype=dtype) + model_config = model_configs[model_config] + kwargs = dict(model_config=model_config, dtype=dtype) outputs = _test_cuda_graphs_with_interleaved_pipeline_parallelism( with_graph=False, **kwargs, diff --git a/tests/pytorch/test_fused_optimizer.py b/tests/pytorch/test_fused_optimizer.py index d19fc5a521..4d4eb38342 100644 --- a/tests/pytorch/test_fused_optimizer.py +++ b/tests/pytorch/test_fused_optimizer.py @@ -4,6 +4,7 @@ from itertools import product import copy +from contextlib import nullcontext import pytest import torch @@ -174,6 +175,216 @@ def test_frozen_model(self): torch.testing.assert_close(ref_param, tst_param) + def gen_precision_aware_test( + self, + use_fp8_params, + param_dtype, + use_master_weights, + master_weight_dtype, + grad_dtype, + exp_avg_dtype, + exp_avg_sq_dtype, + model_rtol=None, + model_atol=None, + master_rtol=None, + master_atol=None, + skip_assert=False, + ): + build_model_context = nullcontext + build_model_context_args = {} + if use_fp8_params: + build_model_context = fp8_model_init + build_model_context_args["enabled"] = True + + with build_model_context(**build_model_context_args): + model = MultiheadAttention( + hidden_size=1024, + num_attention_heads=16, + layer_number=1, + params_dtype=param_dtype, + fuse_qkv_params=True, + ).cuda() + + ref_params = [] + model_params = [] + + for p in model.parameters(): + if p.requires_grad: + ref_params.append(p.detach().clone().float()) + model_params.append(p) + + options = { + "lr": 1, + "betas": (0.1, 0.25), + "eps": 1e-08, + "weight_decay": 0, + "amsgrad": False, + } + ref_optim = torch.optim.Adam(ref_params, **options) + tst_optim = te.optimizers.FusedAdam( + model_params, + master_weights=use_master_weights, + master_weight_dtype=master_weight_dtype, + exp_avg_dtype=exp_avg_dtype, + exp_avg_sq_dtype=exp_avg_sq_dtype, + use_decoupled_grad=True, + **options, + ) + + def test_one_iteration(ref_optimizer, tst_optimizer): + for p_ref, p in zip(ref_params, model_params): + p_ref.grad = torch.rand_like(p_ref) + p.decoupled_grad = p_ref.grad.clone().to(grad_dtype) + ref_optimizer.step() + tst_optimizer.step() + if use_master_weights: + master_weights_to_fp32 = [ + tst_optim.get_unscaled_state(p, "master_param") for p in model_params + ] + if not skip_assert: + torch.testing.assert_close( + ref_params, + master_weights_to_fp32, + rtol=master_rtol, + atol=master_atol, + equal_nan=True, + ) + ref_params_to_model_dtype = [p.to(param_dtype) for p in ref_params] + if not skip_assert: + torch.testing.assert_close( + ref_params_to_model_dtype, + model_params, + rtol=model_rtol, + atol=model_atol, + equal_nan=True, + ) + + for i in range(self.iters): + test_one_iteration(ref_optim, tst_optim) + + state_dict = tst_optim.state_dict() + tst_optim = te.optimizers.FusedAdam( + model_params, + master_weights=use_master_weights, + master_weight_dtype=master_weight_dtype, + exp_avg_dtype=exp_avg_dtype, + exp_avg_sq_dtype=exp_avg_sq_dtype, + use_decoupled_grad=True, + **options, + ) + tst_optim.load_state_dict(state_dict) + + for i in range(self.iters): + test_one_iteration(ref_optim, tst_optim) + + def test_fp32_no_master(self): + 
self.gen_precision_aware_test(
+            use_fp8_params=False,
+            param_dtype=torch.float32,
+            use_master_weights=False,
+            master_weight_dtype=torch.float32,
+            grad_dtype=torch.float32,
+            exp_avg_dtype=torch.float32,
+            exp_avg_sq_dtype=torch.float32,
+        )
+
+    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
+    def test_fp32_master(self):
+        self.gen_precision_aware_test(
+            use_fp8_params=False,
+            param_dtype=torch.bfloat16,
+            use_master_weights=True,
+            master_weight_dtype=torch.float32,
+            grad_dtype=torch.float32,
+            exp_avg_dtype=torch.float32,
+            exp_avg_sq_dtype=torch.float32,
+        )
+
+    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
+    def test_fp16_master(self):
+        self.gen_precision_aware_test(
+            use_fp8_params=False,
+            param_dtype=torch.bfloat16,
+            use_master_weights=True,
+            master_weight_dtype=torch.half,
+            grad_dtype=torch.float32,
+            exp_avg_dtype=torch.float32,
+            exp_avg_sq_dtype=torch.float32,
+            master_rtol=2e-3,
+            master_atol=2e-3,
+        )
+
+    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
+    def test_bf16_grad(self):
+        self.gen_precision_aware_test(
+            use_fp8_params=False,
+            param_dtype=torch.bfloat16,
+            use_master_weights=True,
+            master_weight_dtype=torch.float32,
+            grad_dtype=torch.bfloat16,
+            exp_avg_dtype=torch.float32,
+            exp_avg_sq_dtype=torch.float32,
+            master_rtol=2e-3,
+            master_atol=2e-3,
+        )
+
+    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
+    def test_fp16_exp_avg(self):
+        self.gen_precision_aware_test(
+            use_fp8_params=False,
+            param_dtype=torch.bfloat16,
+            use_master_weights=True,
+            master_weight_dtype=torch.float32,
+            grad_dtype=torch.float32,
+            exp_avg_dtype=torch.half,
+            exp_avg_sq_dtype=torch.float32,
+            master_rtol=2e-3,
+            master_atol=2e-3,
+        )
+
+    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
+    @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
+    def test_fp8_exp_avg(self):
+        self.gen_precision_aware_test(
+            use_fp8_params=False,
+            param_dtype=torch.bfloat16,
+            use_master_weights=True,
+            master_weight_dtype=torch.float32,
+            grad_dtype=torch.float32,
+            exp_avg_dtype=torch.uint8,
+            exp_avg_sq_dtype=torch.float32,
+            master_rtol=1e-2,
+            master_atol=1e-2,
+        )
+
+    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
+    def test_fp16_exp_avg_sq(self):
+        self.gen_precision_aware_test(
+            use_fp8_params=False,
+            param_dtype=torch.bfloat16,
+            use_master_weights=True,
+            master_weight_dtype=torch.float32,
+            grad_dtype=torch.float32,
+            exp_avg_dtype=torch.float32,
+            exp_avg_sq_dtype=torch.half,
+            master_rtol=2e-3,
+            master_atol=2e-3,
+        )
+
+    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
+    @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
+    def test_fp8_exp_avg_sq(self):
+        self.gen_precision_aware_test(
+            use_fp8_params=False,
+            param_dtype=torch.bfloat16,
+            use_master_weights=True,
+            master_weight_dtype=torch.float32,
+            grad_dtype=torch.float32,
+            exp_avg_dtype=torch.float32,
+            exp_avg_sq_dtype=torch.uint8,
+            skip_assert=True,
+        )
+
     @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
     def test_bf16_model_weight_cast(self):
         dtype = torch.bfloat16
@@ -185,12 +396,10 @@ def test_bf16_model_weight_cast(self):
             fuse_qkv_params=True,
         ).cuda()
         ref_params = []
-        master_params = []
         model_params = []
         for p in model.parameters():
             if p.requires_grad:
                 ref_params.append(p.detach().clone().float())
-                master_params.append(p.detach().clone().float())
model_params.append(p) options = { "lr": 5e-4, @@ -200,12 +409,17 @@ def test_bf16_model_weight_cast(self): "amsgrad": False, } ref_optim = torch.optim.Adam(ref_params, **options) - tst_optim = te.optimizers.FusedAdam(model_params, master_weights=master_params, **options) + tst_optim = te.optimizers.FusedAdam( + model_params, master_weights=True, use_decoupled_grad=True, **options + ) for i in range(self.iters): - self.gen_grad(ref_params, master_params) + for p_ref, p in zip(ref_params, model_params): + p_ref.grad = torch.rand_like(p_ref) + p.decoupled_grad = p_ref.grad.clone() ref_optim.step() tst_optim.step() + master_params = [tst_optim.get_unscaled_state(p, "master_param") for p in model_params] torch.testing.assert_close(ref_params, master_params) model_params_to_fp32 = [p.float() for p in model_params] torch.testing.assert_close( @@ -224,12 +438,10 @@ def test_fp8_model_weight_cast(self): fuse_qkv_params=True, ).cuda() ref_params = [] - master_params = [] model_params = [] for p in model.parameters(): if p.requires_grad: ref_params.append(p.detach().clone().float()) - master_params.append(p.detach().clone().float()) model_params.append(p) options = { "lr": 5e-4, @@ -239,12 +451,17 @@ def test_fp8_model_weight_cast(self): "amsgrad": False, } ref_optim = torch.optim.Adam(ref_params, **options) - tst_optim = te.optimizers.FusedAdam(model_params, master_weights=master_params, **options) + tst_optim = te.optimizers.FusedAdam( + model_params, master_weights=True, use_decoupled_grad=True, **options + ) for i in range(self.iters): - self.gen_grad(ref_params, master_params) + for p_ref, p in zip(ref_params, model_params): + p_ref.grad = torch.rand_like(p_ref) + p.decoupled_grad = p_ref.grad.clone() ref_optim.step() tst_optim.step() + master_params = [tst_optim.get_unscaled_state(p, "master_param") for p in model_params] torch.testing.assert_close(ref_params, master_params) model_params_to_fp32 = [p.float() for p in model_params] torch.testing.assert_close( diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 1d91683ae4..fd2832c1d4 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -10,6 +10,7 @@ import torch import transformer_engine +import transformer_engine.common.recipe import transformer_engine.pytorch as te from transformer_engine.pytorch.float8_tensor import Float8Tensor from transformer_engine.pytorch.fp8 import FP8GlobalStateManager @@ -292,7 +293,7 @@ def test_fp8_scale_update( ) # Check that scaling factors match expected - w_amax_ref = max(w_vals[: step + 2]) + w_amax_ref = max(w_vals[: step + 1]) x_amax_ref = max(x_vals[: step + 1]) dy_amax_ref = max(dy_vals[: step + 1]) w_scale_ref = (fp8_format.value.max_fwd / w_amax_ref) / (2**margin) @@ -633,28 +634,78 @@ def test_bias( torch.testing.assert_close(dx_test, x_ref.grad, **tols) torch.testing.assert_close(db_test, b_ref.grad, **tols) - @pytest.mark.parametrize("weight_shape", ((48, 16), (3, 5))) - @pytest.mark.parametrize("in_shape", ((-1,), (5, 1, -1), (2, 2, 4, -1))) - @pytest.mark.parametrize("dtype", _dtypes) - @pytest.mark.parametrize("fp8_compute", (False, True)) - @pytest.mark.parametrize("fp8_input", (False, True)) - @pytest.mark.parametrize("fp8_weight", (False, True)) - @pytest.mark.parametrize("fp8_grad_output", (False, True)) - @pytest.mark.parametrize("accumulate_into_main_grad", (False, True)) - def test_basic_linear( + @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) + @pytest.mark.parametrize("cast_forward", (False, True)) 
+ @pytest.mark.parametrize("cast_backward", (False, True)) + def test_cast_float8( self, *, - weight_shape: tuple[int, int], - in_shape: Iterable[int], - dtype: torch.dtype, + in_shape: Iterable[int] = (1,), + dtype: torch.dtype = torch.bfloat16, device: torch.device = "cuda", - fp8_compute: bool, - fp8_input: bool, - fp8_weight: bool, - fp8_grad_output: bool, - accumulate_into_main_grad: bool, + cast_forward: bool, + cast_backward: bool, ) -> None: - """GEMM""" + """FP8 cast""" + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + test_is_fp8=True, + ) + x_test = x_test.from_float8().requires_grad_() + dy_ref, dy_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + test_is_fp8=True, + ) + dy_test = dy_test.from_float8() + + # Plain PyTorch implementation + y_ref = x_ref + dx_ref = dy_ref + + # Implementation with fusible operation + op = te_ops.Quantize(forward=cast_forward, backward=cast_backward) + recipe = transformer_engine.common.recipe.DelayedScaling( + fp8_format=transformer_engine.common.recipe.Format.E4M3, + ) + with te.fp8_autocast(fp8_recipe=recipe): + y_test = op(x_test) + y_test.backward(dy_test) + + # Check tensor types + assert is_float8_tensor(y_test) == cast_forward + assert is_float8_tensor(x_test.grad) == cast_backward + + # Check values + tols = dict(rtol=0, atol=0) + y_test = y_test.to(dtype=torch.float64, device="cpu") + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y_test, y_ref, **tols) + torch.testing.assert_close(dx_test, dx_ref, **tols) + + def _test_basic_linear( + self, + *, + weight_shape: tuple[int, int] = (32, 32), + in_shape: Iterable[int] = (32, -1), + dtype: torch.dtype = torch.float32, + device: torch.device = "cuda", + fp8_compute: bool = False, + fp8_input: bool = False, + fp8_weight: bool = False, + fp8_output: bool = False, + fp8_grad_output: bool = False, + fp8_grad_input: bool = False, + accumulate_into_main_grad: bool = False, + ) -> None: + """Helper function for tests with GEMM""" # Make input and weight shapes consistent out_features, in_features = weight_shape @@ -662,7 +713,7 @@ def test_basic_linear( out_shape = in_shape[:-1] + [out_features] # Skip invalid configurations - if fp8_compute or fp8_input or fp8_weight or fp8_grad_output: + if fp8_compute or fp8_input or fp8_weight or fp8_output or fp8_grad_output: if not fp8_available: pytest.skip(reason_for_no_fp8) if torch.device(device).type != "cuda": @@ -674,6 +725,10 @@ def test_basic_linear( or out_features % 16 != 0 ): pytest.skip("FP8 GEMMs require dims that are divisible by 16") + if fp8_output and not fp8_compute: + pytest.skip("FP8 output is only supported with FP8 GEMMs") + if fp8_grad_input and not fp8_compute: + pytest.skip("FP8 grad input is only supported with FP8 GEMMs") # Random data x_ref, x_test = make_reference_and_test_tensors( @@ -713,15 +768,23 @@ def test_basic_linear( op.weight.copy_(w_test) del w_test op.weight.main_grad = torch.full_like(op.weight, 0.5, dtype=torch.float32) - with te.fp8_autocast(enabled=fp8_compute): - y_test = op(x_test) + forward = te_ops.Sequential( + te_ops.Quantize(forward=fp8_input, backward=fp8_grad_input), + op, + te_ops.Quantize(forward=fp8_output, backward=fp8_grad_output), + ) + recipe = transformer_engine.common.recipe.DelayedScaling( + fp8_format=transformer_engine.common.recipe.Format.E4M3, + ) + with 
te.fp8_autocast(enabled=fp8_compute, fp8_recipe=recipe): + y_test = forward(x_test) y_test.backward(dy_test) # Expected numerical error tols = dtype_tols(dtype) if dtype == torch.float32: tols = dtype_tols(torch.float16) # TF32 GEMM - if fp8_compute: + if fp8_compute or fp8_output or fp8_grad_input: tols = dtype_tols( op.weight._fp8_dtype if is_float8_tensor(op.weight) else tex.DType.kFloat8E4M3 ) @@ -750,6 +813,57 @@ def test_basic_linear( ) torch.testing.assert_close(dw_test, w_ref.grad, **tols) + @pytest.mark.parametrize("weight_shape", ((48, 16), (3, 5))) + @pytest.mark.parametrize("in_shape", ((-1,), (5, 1, -1), (2, 2, 4, -1))) + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("fp8_compute", (False, True)) + @pytest.mark.parametrize("accumulate_into_main_grad", (False, True)) + def test_basic_linear( + self, + *, + weight_shape: tuple[int, int], + in_shape: Iterable[int], + dtype: torch.dtype, + fp8_compute: bool, + accumulate_into_main_grad: bool, + ) -> None: + """GEMM""" + self._test_basic_linear( + weight_shape=weight_shape, + in_shape=in_shape, + dtype=dtype, + fp8_compute=fp8_compute, + accumulate_into_main_grad=accumulate_into_main_grad, + ) + + @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) + @pytest.mark.parametrize("fp8_compute", (False, True)) + @pytest.mark.parametrize("fp8_input", (False, True)) + @pytest.mark.parametrize("fp8_weight", (False, True)) + @pytest.mark.parametrize("fp8_output", (False, True)) + @pytest.mark.parametrize("fp8_grad_output", (False, True)) + @pytest.mark.parametrize("fp8_grad_input", (False, True)) + def test_basic_linear_fp8( + self, + *, + fp8_compute: bool, + fp8_input: bool, + fp8_weight: bool, + fp8_output: bool, + fp8_grad_output: bool, + fp8_grad_input: bool, + ) -> None: + """GEMM with FP8 inputs and outputs""" + self._test_basic_linear( + dtype=torch.bfloat16, + fp8_compute=fp8_compute, + fp8_input=fp8_input, + fp8_weight=fp8_weight, + fp8_output=fp8_output, + fp8_grad_output=fp8_grad_output, + fp8_grad_input=fp8_grad_input, + ) + @pytest.mark.parametrize("bias", (False, True)) @pytest.mark.parametrize("fp8_compute", (False, True)) @pytest.mark.parametrize("fp8_weight", (False, True)) @@ -856,6 +970,271 @@ def test_linear( db_test = op.bias.grad.to(dtype=torch.float64, device="cpu") torch.testing.assert_close(db_test, b_ref.grad, **tols) + @pytest.mark.parametrize("weight_shape", ((19,), (16, 4))) + @pytest.mark.parametrize("in_shape", ((-1,), (6, 8, -1))) + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("zero_centered_gamma", (False, True)) + @pytest.mark.parametrize("fp8_input", (False, True)) + @pytest.mark.parametrize("fp8_output", (False, True)) + def test_layer_norm( + self, + *, + weight_shape: Iterable[int], + in_shape: Iterable[int], + dtype: torch.dtype, + device: torch.device = "cuda", + eps: float = 0.3, + zero_centered_gamma: bool, + fp8_input: bool, + fp8_output: bool, + ) -> None: + """Layer norm""" + + # Make input and weight shapes consistent + in_shape = list(in_shape)[:-1] + list(weight_shape) + + # Skip invalid configurations + if fp8_input or fp8_output: + if not fp8_available: + pytest.skip(reason_for_no_fp8) + if torch.device(device).type != "cuda": + pytest.skip("FP8 is only supported on CUDA devices") + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + test_is_fp8=fp8_input, + ) + w_ref, w_test = make_reference_and_test_tensors( + weight_shape, + test_dtype=dtype, + 
test_device=device, + ) + b_ref, b_test = make_reference_and_test_tensors( + weight_shape, + test_dtype=dtype, + test_device=device, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + + # Plain PyTorch implementation + y_ref = torch.nn.functional.layer_norm( + x_ref, + weight_shape, + weight=(w_ref + 1 if zero_centered_gamma else w_ref), + bias=b_ref, + eps=eps, + ) + y_ref.backward(dy_ref) + + # Implementation with fusible operation + op = te_ops.LayerNorm( + weight_shape, + eps=eps, + device=device, + dtype=dtype, + zero_centered_gamma=zero_centered_gamma, + ) + with torch.no_grad(): + op.weight.copy_(w_test) + op.bias.copy_(b_test) + del w_test + del b_test + forward = te_ops.Sequential( + op, + te_ops.Quantize(forward=fp8_output, backward=False), + ) + with te.fp8_autocast(enabled=fp8_output): + y_test = forward(x_test) + y_test.backward(dy_test) + + # Expected numerical error + tols = dtype_tols(dtype) + if fp8_output: + tols = dtype_tols(tex.DType.kFloat8E4M3) + + # Check results + y_test = y_test.to(dtype=torch.float64, device="cpu") + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + dw_test = op.weight.grad.to(dtype=torch.float64, device="cpu") + db_test = op.bias.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y_test, y_ref, **tols) + torch.testing.assert_close(dx_test, x_ref.grad, **tols) + torch.testing.assert_close(dw_test, w_ref.grad, **tols) + torch.testing.assert_close(db_test, b_ref.grad, **tols) + + def test_layer_norm_autocast( + self, + *, + weight_shape: Iterable[int] = (32,), + in_shape: Iterable[int] = (32,), + dtype: torch.dtype = torch.float16, + autocast_dtype: torch.dtype = torch.float32, + device: torch.device = "cuda", + eps: float = 0.3, + ) -> None: + """Layer norm with PyTorch autocast""" + + # Make input and weight shapes consistent + in_shape = list(in_shape)[:-1] + list(weight_shape) + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + test_dtype=autocast_dtype, + test_device=device, + ) + w_ref, w_test = make_reference_and_test_tensors( + weight_shape, + test_dtype=dtype, + test_device=device, + ) + b_ref, b_test = make_reference_and_test_tensors( + weight_shape, + test_dtype=dtype, + test_device=device, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + in_shape, + test_dtype=autocast_dtype, + test_device=device, + requires_grad=False, + ) + + # Plain PyTorch implementation + y_ref = torch.nn.functional.layer_norm( + x_ref, + weight_shape, + weight=w_ref, + bias=b_ref, + eps=eps, + ) + y_ref.backward(dy_ref) + + # Implementation with fusible operation + op = te_ops.LayerNorm( + weight_shape, + eps=eps, + device=device, + dtype=dtype, + ) + with torch.no_grad(): + op.weight.copy_(w_test) + op.bias.copy_(b_test) + del w_test + del b_test + with torch.autocast(device, dtype=autocast_dtype): + y_test = op(x_test) + y_test.backward(dy_test) + + # Check results + assert y_test.dtype == autocast_dtype + y_test = y_test.to(dtype=torch.float64, device="cpu") + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + dw_test = op.weight.grad.to(dtype=torch.float64, device="cpu") + db_test = op.bias.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y_test, y_ref, **dtype_tols(autocast_dtype)) + torch.testing.assert_close(dx_test, x_ref.grad, **dtype_tols(autocast_dtype)) + torch.testing.assert_close(dw_test, w_ref.grad, **dtype_tols(dtype)) + torch.testing.assert_close(db_test, 
b_ref.grad, **dtype_tols(dtype)) + + @pytest.mark.parametrize("weight_shape", ((19,), (16, 4))) + @pytest.mark.parametrize("in_shape", ((-1,), (6, 8, -1))) + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("zero_centered_gamma", (False, True)) + @pytest.mark.parametrize("fp8_input", (False, True)) + @pytest.mark.parametrize("fp8_output", (False, True)) + def test_rmsnorm( + self, + *, + weight_shape: Iterable[int], + in_shape: Iterable[int], + dtype: torch.dtype, + device: torch.device = "cuda", + eps: float = 0.3, + zero_centered_gamma: bool, + fp8_input: bool, + fp8_output: bool, + ) -> None: + """Layer norm""" + + # Make input and weight shapes consistent + in_shape = list(in_shape)[:-1] + list(weight_shape) + + # Skip invalid configurations + if fp8_input or fp8_output: + if not fp8_available: + pytest.skip(reason_for_no_fp8) + if torch.device(device).type != "cuda": + pytest.skip("FP8 is only supported on CUDA devices") + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + test_is_fp8=fp8_input, + ) + w_ref, w_test = make_reference_and_test_tensors( + weight_shape, + test_dtype=dtype, + test_device=device, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + + # Plain PyTorch implementation + inner_dims = tuple(range(len(in_shape) - len(weight_shape), len(in_shape))) + var_ref = x_ref.square().sum(dim=inner_dims, keepdim=True) / math.prod(weight_shape) + if zero_centered_gamma: + y_ref = x_ref / torch.sqrt(eps + var_ref) * (1 + w_ref) + else: + y_ref = x_ref / torch.sqrt(eps + var_ref) * w_ref + y_ref.backward(dy_ref) + + # Implementation with fusible operation + op = te_ops.RMSNorm( + weight_shape, + eps=eps, + device=device, + dtype=dtype, + zero_centered_gamma=zero_centered_gamma, + ) + with torch.no_grad(): + op.weight.copy_(w_test) + del w_test + forward = te_ops.Sequential( + op, + te_ops.Quantize(forward=fp8_output, backward=False), + ) + with te.fp8_autocast(enabled=fp8_output): + y_test = forward(x_test) + y_test.backward(dy_test) + + # Expected numerical error + tols = dtype_tols(dtype) + if fp8_output: + tols = dtype_tols(tex.DType.kFloat8E4M3) + + # Check results + y_test = y_test.to(dtype=torch.float64, device="cpu") + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + dw_test = op.weight.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y_test, y_ref, **tols) + torch.testing.assert_close(dx_test, x_ref.grad, **tols) + torch.testing.assert_close(dw_test, w_ref.grad, **tols) + @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("device", ("cuda", "cpu")) @pytest.mark.parametrize("fp8", (False, True)) @@ -867,6 +1246,11 @@ def test_add_in_place( device: torch.device, fp8: bool, ) -> None: + """Add two tensors + + Join in compute graph. + + """ # Skip invalid configurations if fp8 and not fp8_available: @@ -927,6 +1311,11 @@ def test_make_extra_output( device: torch.device, fp8: bool, ) -> None: + """Output tensor twice + + Split in compute graph. 
+ + """ # Skip invalid configurations if fp8 and not fp8_available: @@ -973,6 +1362,166 @@ def test_make_extra_output( torch.testing.assert_close(y2_test, y2_ref, rtol=0, atol=0) torch.testing.assert_close(dx_test, x_ref.grad, **tols) + @pytest.mark.parametrize("activation", ("relu", "gelu", "geglu", "reglu", "swiglu")) + @pytest.mark.parametrize("out_shape", ((37,), (2, 13), (4, 1, 16))) + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("fp8_input", (False, True)) + @pytest.mark.parametrize("fp8_output", (False, True)) + def test_activation( + self, + *, + activation: str, + out_shape: Iterable[int], + dtype: torch.dtype, + device: torch.device = "cuda", + fp8_input: bool, + fp8_output: bool, + ) -> None: + """Activation functions""" + + # Tensor dimensions + in_shape = list(out_shape) + if activation in ("geglu", "reglu", "swiglu"): + in_shape[-1] *= 2 + + # Skip invalid configurations + if fp8_input or fp8_output: + if not fp8_available: + pytest.skip(reason_for_no_fp8) + if torch.device(device).type != "cuda": + pytest.skip("FP8 is only supported on CUDA devices") + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + test_is_fp8=fp8_input, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + out_shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + + # Plain PyTorch implementation + y_ref: torch.Tensor + if activation == "gelu": + y_ref = torch.nn.functional.gelu(x_ref, approximate="tanh") + elif activation == "relu": + y_ref = torch.nn.functional.relu(x_ref) + elif activation == "geglu": + x1, x2 = x_ref.chunk(2, dim=-1) + y_ref = torch.nn.functional.gelu(x1, approximate="tanh") * x2 + elif activation == "reglu": + x1, x2 = x_ref.chunk(2, dim=-1) + y_ref = torch.nn.functional.relu(x1) * x2 + elif activation == "swiglu": + x1, x2 = x_ref.chunk(2, dim=-1) + y_ref = torch.nn.functional.silu(x1) * x2 + else: + raise ValueError(f"Unexpected activation function ({activation})") + y_ref.backward(dy_ref) + + # Implementation with fusible operation + make_op = dict( + gelu=te_ops.GELU, + relu=te_ops.ReLU, + geglu=te_ops.GEGLU, + reglu=te_ops.ReGLU, + swiglu=te_ops.SwiGLU, + )[activation] + forward = te_ops.Sequential( + make_op(), + te_ops.Quantize(forward=fp8_output, backward=False), + ) + with te.fp8_autocast(enabled=fp8_output): + y_test = forward(x_test) + y_test.backward(dy_test) + + # Expected numerical error + tols = dtype_tols(dtype) + if fp8_output: + tols = dtype_tols(tex.DType.kFloat8E4M3) + + # Check results + y_test = y_test.to(dtype=torch.float64, device="cpu") + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y_test, y_ref, **tols) + torch.testing.assert_close(dx_test, x_ref.grad, **tols) + + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("fp8_output", (False, True)) + @pytest.mark.parametrize("fp8_grad_input", (False, True)) + def test_swiglu( + self, + *, + out_shape: Iterable[int] = (16, 16), + dtype: torch.dtype, + device: torch.device = "cuda", + fp8_output: bool, + fp8_grad_input: bool, + ): + + # Tensor dimensions + in_shape = list(out_shape) + in_shape[-1] *= 2 + + # Skip invalid configurations + fp8 = fp8_output or fp8_grad_input + if fp8: + if not fp8_available: + pytest.skip(reason_for_no_fp8) + if torch.device(device).type != "cuda": + pytest.skip("FP8 is only supported on CUDA devices") + + # FP8 recipe + fp8_recipe = None + if fp8_grad_input: + fp8_recipe = 
transformer_engine.common.recipe.DelayedScaling( + fp8_format=transformer_engine.common.recipe.Format.E4M3, + ) + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + out_shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + + # Plain PyTorch implementation + x1, x2 = x_ref.chunk(2, dim=-1) + y_ref = torch.nn.functional.silu(x1) * x2 + y_ref.backward(dy_ref) + + # Implementation with fusible operation + forward = te_ops.Sequential( + te_ops.Quantize(forward=False, backward=fp8_grad_input), + te_ops.SwiGLU(), + te_ops.Quantize(forward=fp8_output, backward=False), + ) + with te.fp8_autocast(enabled=fp8, fp8_recipe=fp8_recipe): + y_test = forward(x_test) + y_test.backward(dy_test) + + # Expected numerical error + tols = dtype_tols(dtype) + if fp8: + tols = dtype_tols(tex.DType.kFloat8E4M3) + + # Check results + y_test = y_test.to(dtype=torch.float64, device="cpu") + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y_test, y_ref, **tols) + torch.testing.assert_close(dx_test, x_ref.grad, **tols) + class TestFusedOps: """Tests for fused operations""" @@ -1106,88 +1655,6 @@ def test_forward_linear_bias_activation( db_test = model[0].bias.grad.to(dtype=torch.float64, device="cpu") torch.testing.assert_close(db_test, b_ref.grad, **tols) - @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) - def test_fp8_linear( - self, - *, - in_shape: Iterable[int] = (16, 16), - dtype: torch.dtype = torch.bfloat16, - device: torch.device = "cuda", - ) -> None: - """Adjacent linear ops with FP8 enabled""" - - # Make input and weight shapes consistent - in_shape = tuple(in_shape) - weight_shape = (in_shape[-1], in_shape[-1]) - - # Random data - x_ref, x_test = make_reference_and_test_tensors( - in_shape, - test_dtype=dtype, - test_device=device, - test_is_fp8=True, - ) - w0_ref, w0_test = make_reference_and_test_tensors( - weight_shape, - test_dtype=dtype, - test_device=device, - test_is_fp8=True, - ) - w1_ref, w1_test = make_reference_and_test_tensors( - weight_shape, - test_dtype=dtype, - test_device=device, - test_is_fp8=True, - ) - dy_ref, dy_test = make_reference_and_test_tensors( - in_shape, - test_dtype=dtype, - test_device=device, - requires_grad=False, - ) - - # Plain PyTorch implementation - y_ref = torch.nn.functional.linear(x_ref, w0_ref) - y_ref = torch.nn.functional.linear(y_ref, w1_ref) - y_ref.backward(dy_ref) - - # Implementation with fusible operations - with te.fp8_model_init(enabled=True): - model = te_ops.Sequential( - te_ops.BasicLinear( - in_shape[-1], - in_shape[-1], - device=device, - dtype=dtype, - ), - te_ops.BasicLinear( - in_shape[-1], - in_shape[-1], - device=device, - dtype=dtype, - ), - ) - with torch.no_grad(): - model[0].weight.copy_(w0_test) - model[1].weight.copy_(w1_test) - del w0_test, w1_test - with te.fp8_autocast(enabled=True): - y_test = model(x_test) - y_test.backward(dy_test) - - # Expected numerical error - tols = dtype_tols(model[0].weight._fp8_dtype) - - # Check results - y_test = y_test.to(dtype=torch.float64, device="cpu") - dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") - dw0_test = model[0].weight.grad.to(dtype=torch.float64, device="cpu") - dw1_test = model[1].weight.grad.to(dtype=torch.float64, device="cpu") - torch.testing.assert_close(y_test, y_ref, **tols) - torch.testing.assert_close(dx_test, x_ref.grad, **tols) - 
torch.testing.assert_close(dw0_test, w0_ref.grad, **tols) - torch.testing.assert_close(dw1_test, w1_ref.grad, **tols) - @pytest.mark.parametrize("bias", (False, True)) @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("fp8_compute", (False, True)) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index c0f45ada4e..c237dbaeb6 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -64,6 +64,7 @@ def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq model_configs = { + "small": ModelConfig(128, 1e-5, 8, 36, 4, 128), "126m": ModelConfig(768, 1e-5, 12, 64, 12, 2048), } @@ -110,23 +111,30 @@ def dtype_tols(dtype: torch.dtype) -> Dict[str, float]: def assert_allclose( - l1: List[torch.Tensor], - l2: List[torch.Tensor], - atol: float, + l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float, rtol: float = None ) -> bool: """Ensures two lists are equal.""" assert len(l1) == len(l2), "Unequal number of outputs." for i, (t1, t2) in enumerate(zip(l1, l2)): - result = torch.allclose(t1, t2, atol=atol) + tols = dict(atol=atol) + if rtol is not None: + tols["rtol"] = rtol + result = torch.allclose(t1, t2, **tols) if not result: - diff = torch.abs(t1 - t2).flatten() - m = torch.argmax(diff) - msg = ( - f"Outputs not close enough in tensor at idx={i}. " - f"Location of the maximum difference: {m.item()} " - f"with {t1.flatten()[m].item()} vs {t2.flatten()[m].item()} " - f"(diff {diff[m].item()})." - ) + diff = torch.abs(t1 - t2) + tol = atol + (rtol * torch.abs(t2)) + exceed_mask = diff > tol + if exceed_mask.any(): + indices = torch.nonzero(exceed_mask, as_tuple=True) + max_diff = diff[exceed_mask].max() + max_idx = (diff[exceed_mask] == max_diff).nonzero(as_tuple=True)[0][0] + max_location = [idx[max_idx].item() for idx in indices] + msg = ( + f"Outputs not close enough in tensor at idx={i}. " + f"Maximum difference at location {max_location} " + f"with {t1[exceed_mask][max_idx].item()} vs {t2[exceed_mask][max_idx].item()} " + f"(diff {max_diff.item()})." 
+ ) raise AssertionError(msg) @@ -526,7 +534,7 @@ def _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params=False @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) -@pytest.mark.parametrize("model", model_configs.keys()) +@pytest.mark.parametrize("model", ["126m"]) @pytest.mark.parametrize("fp8", all_boolean) @pytest.mark.parametrize("fp8_model_params", all_boolean) def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, fp8_model_params): @@ -631,7 +639,7 @@ def _test_e2e_full_recompute( @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) -@pytest.mark.parametrize("model", model_configs.keys()) +@pytest.mark.parametrize("model", ["126m"]) @pytest.mark.parametrize("fp8", all_boolean) @pytest.mark.parametrize("fp8_model_params", all_boolean) @pytest.mark.parametrize("use_reentrant", all_boolean) @@ -764,7 +772,7 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path= @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) -@pytest.mark.parametrize("model", model_configs.keys()) +@pytest.mark.parametrize("model", ["126m"]) def test_gpt_checkpointing(dtype, bs, model): config = model_configs[model] outputs = _test_e2e_checkpointing(bs, dtype, config, checkpoint=False) @@ -809,7 +817,7 @@ def _test_e2e_gpt_accuracy(block, bs, dtype, config): @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) -@pytest.mark.parametrize("model", model_configs.keys()) +@pytest.mark.parametrize("model", ["small"]) @pytest.mark.parametrize("parallel_attention_mlp", all_boolean) def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp): config = model_configs[model] @@ -868,11 +876,25 @@ def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp): te_outputs = _test_e2e_gpt_accuracy(te_gpt, bs, dtype, config) torch_outputs = _test_e2e_gpt_accuracy(torch_gpt, bs, dtype, config) + atol = { + torch.float32: 5e-3, + torch.half: 5e-2, + torch.bfloat16: 1e-1, + } + # Check output. 
- if dtype == torch.float32: - assert_allclose(te_outputs[0], torch_outputs[0], 5e-3) - else: - assert_allclose(te_outputs[0], torch_outputs[0], 5e-2) + assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype]) + + # Check gradients, only for small model + if model == "small": + atol[torch.float32] = 5e-2 + rtol = { + torch.float32: 1e-2, + torch.half: 1e-2, + torch.bfloat16: 1e-2, + } + for te_output, torch_output in zip(te_outputs[1:], torch_outputs[1:]): + assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype]) def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True): @@ -906,7 +928,7 @@ def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True): @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) -@pytest.mark.parametrize("model", model_configs.keys()) +@pytest.mark.parametrize("model", ["small"]) @pytest.mark.parametrize("mask_type", mask_types) def test_mha_accuracy(dtype, bs, model, mask_type): config = model_configs[model] @@ -947,6 +969,21 @@ def test_mha_accuracy(dtype, bs, model, mask_type): else: assert_allclose(te_outputs[0], torch_outputs[0], 5e-2) + # Check gradients, only for small model + if model == "small": + atol = { + torch.float32: 5e-2, + torch.half: 5e-2, + torch.bfloat16: 5e-2, + } + rtol = { + torch.float32: 1e-2, + torch.half: 1e-2, + torch.bfloat16: 1e-2, + } + for te_output, torch_output in zip(te_outputs[1:], torch_outputs[1:]): + assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype]) + def _test_granular_accuracy(block, bs, dtype, config): reset_rng_states() @@ -1002,7 +1039,7 @@ def _test_dpa_accuracy(block, bs, dtype, config): @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) -@pytest.mark.parametrize("model", model_configs.keys()) +@pytest.mark.parametrize("model", ["126m"]) def test_dpa_accuracy(dtype, bs, model): config = model_configs[model] @@ -1034,10 +1071,13 @@ def test_dpa_accuracy(dtype, bs, model): else: assert_allclose(te_outputs[0], torch_outputs[0], 5e-2) + for te_output, torch_output in zip(te_outputs[1:], torch_outputs[1:]): + assert_allclose(te_output, torch_output, atol=5e-2, rtol=1e-2) + @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) -@pytest.mark.parametrize("model", model_configs.keys()) +@pytest.mark.parametrize("model", ["small"]) def test_linear_accuracy(dtype, bs, model): config = model_configs[model] @@ -1066,15 +1106,20 @@ def test_linear_accuracy(dtype, bs, model): torch_outputs = _test_granular_accuracy(torch_linear, bs, dtype, config) # Check output. 
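[Editor's note] The gradient checks added above all follow one pattern: element 0 of each result list is the forward output, checked with its own tolerance, and the remaining elements are gradients checked against per-dtype atol/rtol tables. A hedged sketch of that pattern (tolerance values taken from the "small" GPT case above; `check_outputs` is an illustrative name, not a test helper):

```python
import torch

# Per-dtype gradient tolerances, mirroring the "small"-model GPT checks above.
GRAD_ATOL = {torch.float32: 5e-2, torch.half: 5e-2, torch.bfloat16: 1e-1}
GRAD_RTOL = {torch.float32: 1e-2, torch.half: 1e-2, torch.bfloat16: 1e-2}

def check_outputs(te_outputs, torch_outputs, dtype, out_atol):
    # Element 0 is the forward output; the remaining elements are gradients.
    torch.testing.assert_close(te_outputs[0], torch_outputs[0], atol=out_atol, rtol=0.0)
    for te_grad, torch_grad in zip(te_outputs[1:], torch_outputs[1:]):
        torch.testing.assert_close(
            te_grad, torch_grad, atol=GRAD_ATOL[dtype], rtol=GRAD_RTOL[dtype]
        )

# Identical lists pass trivially.
outs = [torch.randn(2, 3) for _ in range(3)]
check_outputs(outs, [t.clone() for t in outs], torch.float32, out_atol=5e-3)
```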
- if dtype == torch.float32: - assert_allclose(te_outputs[0], torch_outputs[0], 5e-3) - else: - assert_allclose(te_outputs[0], torch_outputs[0], 5e-2) + if model == "small": + tolerance = 5e-3 if dtype == torch.float32 else 5e-2 + rtol = { + torch.float32: 1.3e-6, + torch.half: 1e-2, + torch.bfloat16: 2e-2, + } + for te_output, torch_output in zip(te_outputs, torch_outputs): + assert_allclose(te_output, torch_output, tolerance, rtol[dtype]) @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) -@pytest.mark.parametrize("model", model_configs.keys()) +@pytest.mark.parametrize("model", ["126m"]) @pytest.mark.parametrize("eps", [1e-1, 1e-3, 1e-5, 1e-7]) @pytest.mark.parametrize("zero_centered_gamma", all_boolean) def test_rmsnorm_accuracy(dtype, bs, model, eps, zero_centered_gamma): @@ -1102,18 +1147,29 @@ def test_rmsnorm_accuracy(dtype, bs, model, eps, zero_centered_gamma): te_outputs = _test_granular_accuracy(te_rmsnorm, bs, dtype, config) torch_outputs = _test_granular_accuracy(torch_rmsnorm, bs, dtype, config) - # Check output. atol = { torch.float32: 1e-7, torch.half: 2e-3, torch.bfloat16: 2e-2, } + + # Check output. assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype]) + atol[torch.float32] = 2e-3 + rtol = { + torch.float32: 1.3e-6, + torch.half: 1e-3, + torch.bfloat16: 1.6e-2, + } + # Check gradients + for te_output, torch_output in zip(te_outputs[1:], torch_outputs[1:]): + assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype]) + @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) -@pytest.mark.parametrize("model", model_configs.keys()) +@pytest.mark.parametrize("model", ["126m"]) @pytest.mark.parametrize("eps", [1e-1, 1e-3, 1e-5, 1e-7]) @pytest.mark.parametrize("zero_centered_gamma", all_boolean) def test_layernorm_accuracy(dtype, bs, model, eps, zero_centered_gamma): @@ -1142,18 +1198,29 @@ def test_layernorm_accuracy(dtype, bs, model, eps, zero_centered_gamma): te_outputs = _test_granular_accuracy(te_layernorm, bs, dtype, config) torch_outputs = _test_granular_accuracy(torch_layernorm, bs, dtype, config) - # Check output. atol = { torch.float32: 1e-7, torch.half: 2e-3, torch.bfloat16: 2e-2, } + + # Check output. assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype]) + rtol = { + torch.float32: 1.3e-6, + torch.half: 1e-3, + torch.bfloat16: 1.6e-2, + } + atol[torch.float32] = 1e-4 + # Check gradients + for te_output, torch_output in zip(te_outputs[1:], torch_outputs[1:]): + assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype]) + @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) -@pytest.mark.parametrize("model", model_configs.keys()) +@pytest.mark.parametrize("model", ["small"]) @pytest.mark.parametrize("normalization", all_normalizations) @pytest.mark.parametrize("zero_centered_gamma", all_boolean) def test_layernorm_linear_accuracy(dtype, bs, model, normalization, zero_centered_gamma): @@ -1195,18 +1262,34 @@ def test_layernorm_linear_accuracy(dtype, bs, model, normalization, zero_centere te_outputs = _test_granular_accuracy(te_ln_linear, bs, dtype, config) torch_outputs = _test_granular_accuracy(torch_ln_linear, bs, dtype, config) - # Check output. atol = { torch.float32: 2.5e-4, torch.half: 2e-3, torch.bfloat16: 2e-2, } + + # Check output. 
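[Editor's note] The RMSNorm/LayerNorm accuracy tests above sweep `eps` and `zero_centered_gamma`. For orientation, a generic reference RMSNorm with the zero-centered-gamma convention (weights stored as an offset around 1) looks as follows; this is a hedged sketch, not necessarily identical to the tests' own torch reference module:

```python
import torch

def rmsnorm_ref(x: torch.Tensor, gamma: torch.Tensor, eps: float,
                zero_centered_gamma: bool) -> torch.Tensor:
    # Zero-centered gamma stores the scale as an offset around 1.
    weight = gamma + 1 if zero_centered_gamma else gamma
    rms = torch.sqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + eps)
    return (x.float() / rms).to(x.dtype) * weight

x = torch.randn(4, 768)
gamma = torch.zeros(768)  # identity scale when zero_centered_gamma=True
y = rmsnorm_ref(x, gamma, eps=1e-5, zero_centered_gamma=True)
```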
assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype]) + if model == "small": + atol = { + torch.float32: 1e-3, + torch.half: 5e-2, + torch.bfloat16: 5e-2, + } + rtol = { + torch.float32: 1e-3, + torch.half: 4e-2, + torch.bfloat16: 4e-2, + } + # Check gradients + for te_output, torch_output in zip(te_outputs[1:], torch_outputs[1:]): + assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype]) + @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) -@pytest.mark.parametrize("model", model_configs.keys()) +@pytest.mark.parametrize("model", ["small"]) @pytest.mark.parametrize("activation", all_activations) @pytest.mark.parametrize("normalization", all_normalizations) def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization): @@ -1246,11 +1329,26 @@ def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization): te_outputs = _test_granular_accuracy(te_ln_mlp, bs, dtype, config) torch_outputs = _test_granular_accuracy(torch_ln_mlp, bs, dtype, config) + atol = { + torch.float32: 2e-2, + torch.half: 5e-2, + torch.bfloat16: 5e-2, + } + # Check output. - if dtype == torch.float32: - assert_allclose(te_outputs[0], torch_outputs[0], 5e-3) - else: - assert_allclose(te_outputs[0], torch_outputs[0], 5e-2) + assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype]) + + # Check gradients, only for small model + rtol = { + torch.float32: 1e-3, + torch.half: 1e-2, + torch.bfloat16: 4e-2, + } + atol[torch.half] = 2e-1 + atol[torch.bfloat16] = 2e-1 + if model == "small": + for te_output, torch_output in zip(te_outputs[1:], torch_outputs[1:]): + assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype]) def _test_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, fp8=False): @@ -1301,7 +1399,7 @@ def _test_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, fp8=False @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("num_gemms", [3, 6]) @pytest.mark.parametrize("bs", batch_sizes) -@pytest.mark.parametrize("model", model_configs.keys()) +@pytest.mark.parametrize("model", ["126m"]) @pytest.mark.parametrize("fp8", all_boolean) @pytest.mark.parametrize("fp8_model_params", all_boolean) def test_grouped_linear_accuracy( @@ -1361,7 +1459,7 @@ def test_grouped_linear_accuracy_parallel_mode(parallel_mode): dtype=torch.float32, num_gemms=6, bs=2, - model=list(model_configs.keys())[0], + model="126m", fp8=True, fp8_model_params=True, parallel_mode=parallel_mode, @@ -1374,7 +1472,7 @@ def test_grouped_linear_accuracy_single_gemm(): dtype=torch.float32, num_gemms=1, bs=2, - model=list(model_configs.keys())[0], + model="126m", fp8=True, fp8_model_params=True, ) @@ -1475,7 +1573,7 @@ def _generate_random_numbers(n, total_sum): @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("num_gemms", [3, 6]) @pytest.mark.parametrize("bs", batch_sizes) -@pytest.mark.parametrize("model", model_configs.keys()) +@pytest.mark.parametrize("model", ["126m"]) @pytest.mark.parametrize("fp8", [True]) @pytest.mark.parametrize("fp8_model_params", all_boolean) def test_padding_grouped_linear_accuracy( @@ -1594,7 +1692,7 @@ def train_step(): @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) -@pytest.mark.parametrize("model", model_configs.keys()) +@pytest.mark.parametrize("model", ["126m"]) def test_gpt_cuda_graph(dtype, bs, model): config = model_configs[model] @@ -1686,7 +1784,7 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params): 
@pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) -@pytest.mark.parametrize("model", model_configs.keys()) +@pytest.mark.parametrize("model", ["126m"]) def test_gpt_fp8_parameters(dtype, bs, model): if not fp8_available: pytest.skip(reason_for_no_fp8) @@ -1710,7 +1808,7 @@ def test_gpt_fp8_parameters(dtype, bs, model): @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) -@pytest.mark.parametrize("model", model_configs.keys()) +@pytest.mark.parametrize("model", ["126m"]) def test_transformer_layer_hidden_states_format(dtype, bs, model): config = model_configs[model] diff --git a/tests/pytorch/test_torch_save_load.py b/tests/pytorch/test_torch_save_load.py index 6af7ede234..7bf8fb99d5 100644 --- a/tests/pytorch/test_torch_save_load.py +++ b/tests/pytorch/test_torch_save_load.py @@ -17,7 +17,9 @@ import pytest import torch +import transformer_engine.common import transformer_engine.pytorch as te +import transformer_engine.pytorch.ops as te_ops import transformer_engine_torch as tex from transformer_engine.pytorch.cpp_extensions import fp8_gemm, cast_to_fp8 from transformer_engine.pytorch.fp8 import FP8GlobalStateManager @@ -287,3 +289,186 @@ def test_fp8_model_checkpoint( torch.testing.assert_close( model.weight._scale_inv.item(), fp8_meta_fwd_ref["scale_inv"][meta_index].item() ) + + +@pytest.mark.parametrize("fp8", (False, True)) +@pytest.mark.parametrize("save_fp8_model", (False, True)) +@pytest.mark.parametrize("load_fp8_model", (False, True)) +def test_sequential_model( + *, + in_shape: Iterable[int] = (16, 16), + dtype: torch.dtype = torch.float32, + device: torch.device = "cuda", + save_steps: int = 2, + load_steps: int = 2, + fp8: bool, + save_fp8_model: bool, + load_fp8_model: bool, +) -> None: + + # Skip invalid configurations + if fp8 or save_fp8_model or load_fp8_model: + if not fp8_available: + pytest.skip(reason_for_no_fp8) + if torch.device(device).type != "cuda": + pytest.skip("FP8 is only supported on CUDA devices") + + # FP8 recipe + margin = 2 + fp8_format = transformer_engine.common.recipe.Format.E4M3 + recipe = transformer_engine.common.recipe.DelayedScaling( + margin=margin, + fp8_format=fp8_format, + amax_history_len=8, + amax_compute_algo="max", + ) + + # Construct model to save to checkpoint + with te.fp8_model_init(enabled=save_fp8_model): + model = te_ops.Sequential( + te_ops.Linear(in_shape[-1], in_shape[-1], device=device, dtype=dtype), + ) + with torch.no_grad(): + torch.rand(model[0].weight.size(), out=model[0].weight) + torch.rand(model[0].bias.size(), out=model[0].bias) + + # Synthetic data + xs_ref = [ + torch.rand(in_shape, dtype=dtype, device=device) for _ in range(save_steps + load_steps) + ] + dys_ref = [ + torch.rand(in_shape, dtype=dtype, device=device) for _ in range(save_steps + load_steps) + ] + + def train_step( + model: te_ops.Sequential, + x: torch.Tensor, + dy: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Helper function to perform training step""" + x = x.detach().clone().requires_grad_() + dy = dy.detach().clone() + with te.fp8_autocast(enabled=fp8, fp8_recipe=recipe): + y = model(x) + y.backward(dy) + with torch.no_grad(): + for param in model.parameters(): + param += 0.125 + return ( + y.detach().clone(), + x.grad.detach().clone(), + model[0].weight.detach().float().clone(), + ) + + # Initial training steps with saved model + ys_ref = [] + dxs_ref = [] + ws_ref = [] + for step in range(save_steps): + y, dx, w = train_step(model, 
xs_ref[step], dys_ref[step]) + ys_ref.append(y) + dxs_ref.append(dx) + ws_ref.append(w) + + # Keep track of FP8 metadata if needed + fp8_meta_ref = dict(input={}, param={}, grad_output={}) + if fp8: + for fp8_meta_type, fp8_meta_key in ( + ("input", "scaling_fwd"), + ("param", "scaling_fwd"), + ("grad_output", "scaling_bwd"), + ): + m_model = model[0].basic_ops[0].get_fp8_meta(fp8_meta_type)[fp8_meta_key] + m_ref = fp8_meta_ref[fp8_meta_type] + m_ref["amax"] = m_model.amax_history.detach().clone() + m_ref["scale"] = m_model.scale.detach().clone() + m_ref["scale_inv"] = m_model.scale_inv.detach().clone() + del m_model, m_ref + + # Save checkpoint + byte_stream = io.BytesIO() + torch.save(model.state_dict(), byte_stream) + model_bytes = byte_stream.getvalue() + del byte_stream + + # More training steps with saved model + for step in range(save_steps, save_steps + load_steps): + y, dx, w = train_step(model, xs_ref[step], dys_ref[step]) + ys_ref.append(y) + dxs_ref.append(dx) + ws_ref.append(w) + + # Disturb and destroy model + with torch.no_grad(): + for param in model.parameters(): + param.zero_() + model[0].basic_ops[0]._fp8_metas = None + del model + + # Construct new model to load from checkpoint + with te.fp8_model_init(enabled=load_fp8_model): + model = te_ops.Sequential( + te_ops.Linear(in_shape[-1], in_shape[-1], device=device, dtype=dtype), + ) + + # Tolerances for numerical checks + tols = {} + if fp8 or save_fp8_model or load_fp8_model: + tols = dict(rtol=0.125, atol=0.0675) # fp8e4me3 epsilon = 0.0625 + exact_tols = dict(rtol=0, atol=0) + + # Training steps with dummy data + for step in range(save_steps): + y, dx, w = train_step( + model, + torch.zeros_like(xs_ref[step]), + torch.zeros_like(dys_ref[step]), + ) + + # Make sure results don't match saved model + with pytest.raises(AssertionError): + torch.testing.assert_close(y, ys_ref[step], **tols) + with pytest.raises(AssertionError): + torch.testing.assert_close(dx, dxs_ref[step], **tols) + with pytest.raises(AssertionError): + torch.testing.assert_close(w, ws_ref[step], **tols) + + # Make sure new model's FP8 metadata doesn't match saved model + if fp8: + for fp8_meta_type, fp8_meta_key in ( + ("input", "scaling_fwd"), + ("param", "scaling_fwd"), + ("grad_output", "scaling_bwd"), + ): + m_model = model[0].basic_ops[0].get_fp8_meta(fp8_meta_type)[fp8_meta_key] + m_ref = fp8_meta_ref[fp8_meta_type] + with pytest.raises(AssertionError): + torch.testing.assert_close(m_model.amax_history, m_ref["amax"], **exact_tols) + with pytest.raises(AssertionError): + torch.testing.assert_close(m_model.scale, m_ref["scale"], **exact_tols) + with pytest.raises(AssertionError): + torch.testing.assert_close(m_model.scale_inv, m_ref["scale_inv"], **exact_tols) + + # Load checkpoint + model.load_state_dict(torch.load(io.BytesIO(model_bytes))) + del model_bytes + + # Check that new model's FP8 metadata matches saved model + if fp8: + for fp8_meta_type, fp8_meta_key in ( + ("input", "scaling_fwd"), + ("param", "scaling_fwd"), + ("grad_output", "scaling_bwd"), + ): + m_model = model[0].basic_ops[0].get_fp8_meta(fp8_meta_type)[fp8_meta_key] + m_ref = fp8_meta_ref[fp8_meta_type] + torch.testing.assert_close(m_model.amax_history, m_ref["amax"], **exact_tols) + torch.testing.assert_close(m_model.scale, m_ref["scale"], **exact_tols) + torch.testing.assert_close(m_model.scale_inv, m_ref["scale_inv"], **exact_tols) + + # More training steps with loaded model + for step in range(save_steps, save_steps + load_steps): + y, dx, w = train_step(model, xs_ref[step], 
dys_ref[step]) + torch.testing.assert_close(y, ys_ref[step], **tols) + torch.testing.assert_close(dx, dxs_ref[step], **tols) + torch.testing.assert_close(w, ws_ref[step], **tols) diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py new file mode 100644 index 0000000000..a8b181a187 --- /dev/null +++ b/tests/pytorch/utils.py @@ -0,0 +1,85 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +from __future__ import annotations + +import torch + +import transformer_engine +import transformer_engine.pytorch as te +import transformer_engine_torch as tex + + +def str_to_dtype(dtype: str | torch.dtype) -> torch.dtype: + """Convert type name to PyTorch dtype""" + if isinstance(dtype, torch.dtype): + return dtype + name = str(dtype).strip().lower() + if name.startswith("torch."): + name = name.replace("torch.", "", 1) + if name.startswith("fp"): + name = name.replace("fp", "float", 1) + dtype = dict( + float32=torch.float32, + float=torch.float32, + float64=torch.float64, + double=torch.float64, + float16=torch.float16, + half=torch.float16, + bfloat16=torch.bfloat16, + bf16=torch.bfloat16, + float8_e4m3fn=torch.float8_e4m3fn, + float8_e4m3=torch.float8_e4m3fn, + float8e4m3=torch.float8_e4m3fn, + float8=torch.float8_e4m3fn, + float8_e5m2=torch.float8_e5m2, + float8e5m2=torch.float8_e5m2, + uint8=torch.uint8, + byte=torch.uint8, + int8=torch.int8, + char=torch.int8, + int16=torch.int16, + short=torch.int16, + int32=torch.int32, + int=torch.int32, + int64=torch.int64, + long=torch.int64, + bool=torch.bool, + )[name] + return dtype + + +def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]: + """Estimated numerical error for a datatype + + Based on tolerances for torch.testing.assert_close. + + """ + + # Transformer Engine dtypes + if isinstance(dtype, tex.DType): + dtype = { + tex.DType.kByte: torch.uint8, + tex.DType.kInt32: torch.int32, + tex.DType.kFloat32: torch.float32, + tex.DType.kFloat16: torch.half, + tex.DType.kBFloat16: torch.bfloat16, + tex.DType.kFloat8E4M3: torch.float8_e4m3fn, + tex.DType.kFloat8E5M2: torch.float8_e5m2, + }[dtype] + + # PyTorch dtypes + if dtype == torch.float16: + return dict(rtol=1e-3, atol=1e-5) + if dtype == torch.bfloat16: + return dict(rtol=1.6e-2, atol=1e-5) + if dtype == torch.float32: + return dict(rtol=1.3e-6, atol=1e-5) + if dtype == torch.float64: + return dict(rtol=1e-7, atol=1e-7) + if dtype == torch.float8_e4m3fn: + return dict(rtol=0.125, atol=0.0675) # epsilon = 0.0625 + if dtype == torch.float8_e5m2: + return dict(rtol=0.25, atol=0.125) # epsilon = 0.152 + raise ValueError(f"Unsupported dtype ({dtype})") diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index cabb2e2aea..3784689f9a 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -46,7 +46,7 @@ find_package(Python COMPONENTS Interpreter Development.Module REQUIRED) include_directories(${PROJECT_SOURCE_DIR}/..) 
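[Editor's note] Returning briefly to the new `tests/pytorch/utils.py` added above: its two helpers resolve a dtype from a loose string spelling and return matching tolerances for `torch.testing.assert_close`. A short usage sketch, assuming it is imported as a sibling module from within `tests/pytorch/` (the import path is an assumption, and `transformer_engine` must be installed for the module to import):

```python
import torch
from utils import str_to_dtype, dtype_tols  # assumed import path from tests/pytorch/

dtype = str_to_dtype("bf16")   # -> torch.bfloat16
tols = dtype_tols(dtype)       # -> {"rtol": 1.6e-2, "atol": 1e-5}

a = torch.randn(32, 32, dtype=dtype)
b = a.clone()
torch.testing.assert_close(a, b, **tols)
```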
set(transformer_engine_SOURCES) list(APPEND transformer_engine_SOURCES - pycudnn.cpp + cudnn_utils.cpp transformer_engine.cpp common.cu transpose/cast_transpose.cu @@ -80,7 +80,11 @@ list(APPEND transformer_engine_SOURCES fused_softmax/scaled_upper_triang_masked_softmax.cu fused_softmax/scaled_aligned_causal_masked_softmax.cu fused_rope/fused_rope.cu - recipe/delayed_scaling.cu) + recipe/delayed_scaling.cu + comm_gemm_overlap/userbuffers/ipcsocket.cc + comm_gemm_overlap/userbuffers/userbuffers-host.cpp + comm_gemm_overlap/userbuffers/userbuffers.cu + comm_gemm_overlap/comm_gemm_overlap.cpp) add_library(transformer_engine SHARED ${transformer_engine_SOURCES}) target_include_directories(transformer_engine PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/include") @@ -93,6 +97,15 @@ target_include_directories(transformer_engine PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}") +# Compiling Userbuffers with native MPI bootstrapping requires linking against MPI +option(NVTE_UB_WITH_MPI "Bootstrap Userbuffers with MPI" OFF) +if (NVTE_UB_WITH_MPI) + find_package(MPI REQUIRED) + target_link_libraries(transformer_engine PUBLIC MPI::MPI_CXX) + target_include_directories(transformer_engine PRIVATE ${MPI_CXX_INCLUDES}) + target_compile_definitions(transformer_engine PUBLIC NVTE_UB_WITH_MPI) +endif() + # Hack to enable dynamic loading in cuDNN frontend target_compile_definitions(transformer_engine PUBLIC NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING) diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp new file mode 100644 index 0000000000..a663385b68 --- /dev/null +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -0,0 +1,994 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include + +#include +#include + +#include "common/common.h" +#include "common/util/cuda_driver.h" +#include "common/util/cuda_runtime.h" +#include "common/util/logging.h" +#include "common/util/system.h" +#include "userbuffers/userbuffers.h" + +#define HALF_BYTES 2 +#define UB_MAX_SM 32 + +using namespace std::placeholders; + +namespace transformer_engine { + +/*************************************************************************************************** + * Comm+GEMM Overlap Common Core + **************************************************************************************************/ + +bool ubuf_built_with_mpi() { +#ifdef NVTE_UB_WITH_MPI + return true; +#else + return false; +#endif +} + +CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numlocal, int mynode, + int numnodes, int tp_size, ExtAllgatherOp allgather_handle, + ExtBarrierOp barrier_handle, int num_splits, int num_max_streams, + int comm_cga_size, int num_comm_sm, bool set_sm_margin, + bool use_ce, bool atomic_gemm) { + // Initialize userbuf communicator + if (!_comm_created) { + if (myrank == 0) { + printf("!!! 
[UB] Create Userbuffers Communicator\n"); + } +#ifdef NVTE_UB_WITH_MPI + create_communicator_grouped2_mpi(&_ub_comm, 1, 1, tp_size, 1); +#else + create_communicator_grouped2(&_ub_comm, myrank, numranks, mylocal, numlocal, mynode, numnodes, + allgather_handle, barrier_handle, 1, 1, tp_size, 1); +#endif + _comm_created = true; + } + _use_ce = static_cast(use_ce); + _num_comm_sm = num_comm_sm; + _cga_size = comm_cga_size; + + for (int i = 0; i < std::min(num_max_streams, num_splits); i++) { + cudaStream_t stream; + NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, -1)); + _stream_compute.push_back(std::move(stream)); + } + + _num_splits = num_splits; + _rank = _ub_comm->myrank; + _tp_size = tp_size; + _tp_id = _rank % _tp_size; + + // Set the number of SMs for GEMM with margin + int sm_count = transformer_engine::cuda::sm_count(); + _math_sms = (set_sm_margin) ? sm_count - num_comm_sm : sm_count; + _math_sms -= transformer_engine::getenv("NVTE_EXT_MARGIN_SM", 0); + + _atomic_gemm = atomic_gemm; + if (_atomic_gemm) { + void *counter_ptr; + size_t counter_bytes = _num_splits * 2 * sizeof(int32_t); + NVTE_CHECK_CUDA(cudaMalloc(&counter_ptr, counter_bytes)); + NVTE_CHECK_CUDA(cudaMemset(counter_ptr, 0, counter_bytes)); + NVTE_CHECK_CUDA(cudaMemset(counter_ptr, 1, counter_bytes / 2)); + _counter = TensorWrapper(counter_ptr, std::vector{static_cast(_num_splits * 2)}, + DType::kInt32); + } + // CUDA event creation + cudaEventCreateWithFlags(&_start_compute, 0); + cudaEventCreateWithFlags(&_stop_compute, 0); + cudaEventCreateWithFlags(&_start_comm, 0); + cudaEventCreateWithFlags(&_stop_comm, 0); +} + +CommOverlapCore::~CommOverlapCore() { + cudaEventDestroy(_stop_comm); + cudaEventDestroy(_start_comm); + cudaEventDestroy(_stop_compute); + cudaEventDestroy(_start_compute); + + if (_atomic_gemm) cudaFree(_counter.dptr()); + + for (size_t i = 0; i < _stream_compute.size(); i++) cudaStreamDestroy(_stream_compute[i]); + + if (_comm_created) { +#ifdef NVTE_UB_WITH_MPI + destroy_communicator_mpi(_ub_comm); +#else + destroy_communicator(_ub_comm); +#endif + _comm_created = false; + } +} + +/*************************************************************************************************** + * Comm+GEMM Overlap Base (Pipelined / Collective) + **************************************************************************************************/ + +CommOverlapBase::CommOverlapBase(const std::vector &buffer_shape, DType buffer_dtype, + int myrank, int numranks, int mylocal, int numlocal, int mynode, + int numnodes, int tp_size, ExtAllgatherOp allgather_handle, + ExtBarrierOp barrier_handle, int num_splits, int num_max_streams, + int comm_cga_size, int num_comm_sm, bool set_sm_margin, + bool atomic_gemm) + : CommOverlapCore(myrank, numranks, mylocal, numlocal, mynode, numnodes, tp_size, + allgather_handle, barrier_handle, num_splits, num_max_streams, comm_cga_size, + num_comm_sm, set_sm_margin, false, atomic_gemm) { + _rs_kernel_type = getenv("NVTE_RS_STRIDED_ATOMIC", 0); + NVTE_CHECK(_rs_kernel_type >= 0 && _rs_kernel_type <= 3, + "Invalid choice for NVTE_RS_STRIDED_ATOMIC: Must be 0 (non-atomic), 1 (atomic) ", + "or 2 (multi-atomic)."); + + NVTE_CHECK(buffer_shape.size() == 2, "Userbuffer shape must be 2-dimensional!"); + size_t buffer_bytes = buffer_shape[0] * buffer_shape[1] * typeToSize(buffer_dtype); + void *buffer_ptr; + _ub_reg = register_user_buffer_collective(&buffer_ptr, buffer_bytes, _ub_comm, true); + if (_ub_comm->myrank == 0) printf("!!! 
[UB] Register UBuf %d\n", _ub_reg); + _ubuf = TensorWrapper(buffer_ptr, buffer_shape, buffer_dtype); + + NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&_stream_comm, cudaStreamNonBlocking, -1)); + NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_start_d2dcopy, 0)); +} + +CommOverlapBase::~CommOverlapBase() { + cudaEventDestroy(_start_d2dcopy); + cudaStreamDestroy(_stream_comm); +} + +/* +** Bulk GEMM + COMM +** This function assumes the communication input is pre-copied to _ubuf +*/ +void CommOverlapBase::bulk_overlap(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, + CommOverlapType comm_type, TensorWrapper &rs_output, + cudaStream_t stream_main) { + int ori_sms = _ub_comm->sms; + _ub_comm->use_ce = _use_ce; + _ub_comm->sms = _num_comm_sm; + _ub_comm->cga_size = _cga_size; + + // Catch up the default torch stream + NVTE_CHECK_CUDA(cudaEventRecord(_start_comm, stream_main)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_comm, 0)); + + // Communication: AG and RS + int comm_elements = (_ubuf.numel() / 2) * _ubuf.element_size(); // UBUF uses 2Byte element size + if (comm_type == CommOverlapType::AG) { + allgather2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm); + } else { + if (_ubuf.element_size() == 1) { + assert(_ubuf_scale_inv_initialized); + comm_elements *= 2; + assert(rs_output.numel() == _ubuf.numel() / _tp_size); + assert(rs_output.size(0) == _ubuf.size(0) / _tp_size); + assert(rs_output.element_size() == 2); + char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); + reducescatter2_userbuff_fp8<__nv_fp8_e5m2>(rs_output_ptr, _ubuf_scale_inv, _ub_reg, 0, + comm_elements, _ub_comm, _stream_comm); + } else { + reducescatter2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm); + } + } + + assert(pre_gelu_out.numel() == 0); + nvte_cublas_gemm(A.data(), B.data(), D.data(), bias.data(), pre_gelu_out.data(), transa, transb, + grad, workspace.data(), accumulate, use_split_accumulator, _math_sms, + stream_main); + + _ub_comm->sms = ori_sms; + NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_comm)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0)); +} // CommOverlapBase::bulk_overlap + +/* +** Split FPROP GEMM + ReduceScatter +*/ +void CommOverlapBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, + bool grad, bool accumulate, bool use_split_accumulator, + bool gemm_overlap, TensorWrapper &rs_output, + cudaStream_t stream_main) { + int ori_sms = _ub_comm->sms; + _ub_comm->use_ce = _use_ce; + _ub_comm->sms = _num_comm_sm; + _ub_comm->cga_size = _cga_size; + // Get GEMM dimensions + size_t m = A.size(0); + size_t k = A.size(1); + size_t n = B.size(0); + size_t m_chunk = m / _num_splits; + size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); + + // Get input, output, and workspace data pointers + char *input_a_chunk_ptr = reinterpret_cast(A.dptr()); + char *output_buf_chunk_ptr = reinterpret_cast(_ubuf.dptr()); + char *workspace_ptr = reinterpret_cast(workspace.dptr()); + char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); + + // Reset atomic counters + int *counter_ptr = reinterpret_cast(_counter.dptr()); + reset_counters(counter_ptr, _num_splits, false, stream_main); + + // Catch up the default torch 
stream + NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[0], _start_compute, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_compute, 0)); + + assert(pre_gelu_out.numel() == 0); + + auto output_d = TensorWrapper(_ubuf.dptr(), {n, m}, D.dtype(), D.amax(), D.scale(), nullptr); + auto workspace_chunk = + TensorWrapper(workspace.dptr(), std::vector{workspace_size_chunk}, workspace.dtype()); + nvte_cublas_atomic_gemm(A.data(), B.data(), output_d.data(), bias.data(), pre_gelu_out.data(), + transa, transb, grad, workspace_chunk.data(), accumulate, + use_split_accumulator, _math_sms, _num_splits, 0, true, _counter.data(), + _stream_compute[0]); + + for (int i = 0; i < _num_splits; i++) { + if (_rs_kernel_type == 1) { + if (i == _num_splits - 1) { + _ub_comm->sms = UB_MAX_SM; + } + if (_ubuf.element_size() == 1) { + assert(_ubuf_scale_inv_initialized); + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + D.dtype(), fp8_type, + reducescatter2_userbuff_strided_atomic_fp8( + rs_output_ptr, _ubuf_scale_inv, _ub_reg, i * m_chunk, m_chunk, n, m, m, _num_splits, + &counter_ptr[i], _ub_comm, _stream_comm);); + } else { + reducescatter2_userbuff_strided_atomic(rs_output_ptr, _ub_reg, i * m_chunk, m_chunk, n, m, + _num_splits, &counter_ptr[i], _ub_comm, + _stream_comm); + } + } else if (_rs_kernel_type == 2) { + if (_ubuf.element_size() == 1) { + assert(_ubuf_scale_inv_initialized); + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + D.dtype(), fp8_type, + reducescatter2_userbuff_strided_multiatomic_fp8( + rs_output_ptr, _ubuf_scale_inv, _ub_reg, m_chunk, m_chunk, n, m, m, _num_splits, + counter_ptr, _ub_comm, _stream_comm);); + } else { + reducescatter2_userbuff_strided_multiatomic(rs_output_ptr, _ub_reg, m_chunk, m_chunk, n, m, + _num_splits, counter_ptr, _ub_comm, + _stream_comm); + } + break; + } else { + consumer(counter_ptr, i, _stream_comm); + if (_ubuf.element_size() == 1) { + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + D.dtype(), fp8_type, + reducescatter2_userbuff_stridedoutput_fp8(rs_output_ptr, _ubuf_scale_inv, + _ub_reg, i * m_chunk, m_chunk, n, m, + _ub_comm, _stream_comm);); + } else { + reducescatter2_userbuff_strided(rs_output_ptr, _ub_reg, i * m_chunk, m_chunk, n, m, + _ub_comm, _stream_comm); + } + } + + rs_output_ptr += m_chunk * rs_output.element_size(); + } + + _ub_comm->sms = ori_sms; + NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[0])); + NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_comm)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_compute, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0)); +} // split_overlap_rs + +/* +** Split FPROP GEMM + ReduceScatter +*/ +void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, + bool grad, bool accumulate, bool use_split_accumulator, + bool gemm_overlap, TensorWrapper &rs_output, + cudaStream_t stream_main) { + // Get GEMM dimensions + int ori_sms = _ub_comm->sms; + _ub_comm->use_ce = _use_ce; + _ub_comm->sms = _num_comm_sm; + _ub_comm->cga_size = _cga_size; + size_t m = A.size(0); + size_t k = A.size(1); + size_t n = B.size(0); + size_t m_chunk = m / _num_splits; + size_t input_a_chunk_size = m_chunk * k; + size_t output_chunk_size = n * m_chunk; + size_t bias_chunk_size = m_chunk; + size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); + + // Get input, 
output, and workspace data pointers + char *input_a_chunk_ptr = reinterpret_cast(A.dptr()); + char *output_buf_chunk_ptr = reinterpret_cast(_ubuf.dptr()); + char *bias_chunk_ptr = reinterpret_cast(bias.dptr()); + char *workspace_ptr = reinterpret_cast(workspace.dptr()); + + char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); + + // Catch up the default torch stream + NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); + for (size_t i = 0; i < _stream_compute.size(); i++) { + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[i], _start_compute, 0)); + } + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_compute, 0)); + + assert(pre_gelu_out.numel() == 0); + + if (gemm_overlap) { + auto input_a_chunk = + TensorWrapper(A.dptr(), {m_chunk, k}, A.dtype(), nullptr, nullptr, A.scale_inv()); + auto output_chunk = + TensorWrapper(_ubuf.dptr(), {m, m_chunk}, D.dtype(), D.amax(), D.scale(), nullptr); + auto bias_chunk = + TensorWrapper(bias.dptr(), {m_chunk}, bias.dtype(), nullptr, nullptr, nullptr); + auto workspace_chunk = TensorWrapper( + workspace.dptr(), std::vector{workspace_size_chunk}, workspace.dtype()); + + nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias_chunk.data(), + pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate, + use_split_accumulator, _math_sms, _stream_compute[0]); + + for (int i = 1; i < _num_splits; i++) { + input_a_chunk_ptr += input_a_chunk_size * B.element_size(); + output_buf_chunk_ptr += output_chunk_size * D.element_size(); + if (bias_chunk_ptr != nullptr) { + bias_chunk_ptr += bias_chunk_size * bias.element_size(); + } + char *workspace_chunk_ptr = + workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk; + + input_a_chunk = TensorWrapper(reinterpret_cast(input_a_chunk_ptr), {m_chunk, k}, + A.dtype(), nullptr, nullptr, A.scale_inv()); + output_chunk = TensorWrapper(reinterpret_cast(output_buf_chunk_ptr), {n, m_chunk}, + D.dtype(), D.amax(), D.scale(), nullptr); + bias_chunk = TensorWrapper(reinterpret_cast(bias_chunk_ptr), {m_chunk}, bias.dtype(), + nullptr, nullptr, nullptr); + workspace_chunk = TensorWrapper(reinterpret_cast(workspace_chunk_ptr), + std::vector{workspace_size_chunk}, workspace.dtype()); + + nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias_chunk.data(), + pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), + accumulate, use_split_accumulator, _math_sms, + _stream_compute[i % _stream_compute.size()]); + + NVTE_CHECK_CUDA( + cudaEventRecord(_start_comm, _stream_compute[(i - 1) % _stream_compute.size()])); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_comm, 0)); + + // Communication chunk + if (_ubuf.element_size() == 1) { + assert(_ubuf_scale_inv_initialized); + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + D.dtype(), fp8_type, + reducescatter2_userbuff_stridedoutput_fp8( + rs_output_ptr, _ubuf_scale_inv, _ub_reg, (i - 1) * output_chunk_size, m_chunk, n, m, + _ub_comm, _stream_comm);); + } else { + reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, (i - 1) * output_chunk_size, + m_chunk, n, m, _ub_comm, _stream_comm); + } + + rs_output_ptr += m_chunk * rs_output.element_size(); + } + int last_compute_stream_id = + (_num_splits + _stream_compute.size() - 1) % _stream_compute.size(); + NVTE_CHECK_CUDA(cudaEventRecord(_start_comm, _stream_compute[last_compute_stream_id])); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_comm, 0)); + + // Last communication chunk with max SM + _ub_comm->sms = 
UB_MAX_SM; + if (_ubuf.element_size() == 1) { + assert(_ubuf_scale_inv_initialized); + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + D.dtype(), fp8_type, + reducescatter2_userbuff_stridedoutput_fp8( + rs_output_ptr, _ubuf_scale_inv, _ub_reg, (_num_splits - 1) * output_chunk_size, + m_chunk, n, m, _ub_comm, _stream_comm);); + } else { + reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, + (_num_splits - 1) * output_chunk_size, m_chunk, n, m, + _ub_comm, _stream_comm); + } + } else { + for (int i = 0; i < _num_splits; i++) { + char *workspace_chunk_ptr = + workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk; + + auto input_a_chunk = TensorWrapper(reinterpret_cast(input_a_chunk_ptr), {m_chunk, k}, + A.dtype(), nullptr, nullptr, A.scale_inv()); + auto output_chunk = TensorWrapper(reinterpret_cast(output_buf_chunk_ptr), + {n, m_chunk}, D.dtype(), D.amax(), D.scale(), nullptr); + auto bias_chunk = TensorWrapper(reinterpret_cast(bias_chunk_ptr), {m_chunk}, + bias.dtype(), nullptr, nullptr, nullptr); + auto workspace_chunk = + TensorWrapper(reinterpret_cast(workspace_chunk_ptr), + std::vector{workspace_size_chunk}, workspace.dtype()); + + nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias_chunk.data(), + pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), + accumulate, use_split_accumulator, _math_sms, + _stream_compute[i % _stream_compute.size()]); + + NVTE_CHECK_CUDA(cudaEventRecord(_start_comm, _stream_compute[i % _stream_compute.size()])); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_comm, 0)); + + // Communication chunk. Uses MAX_SM at the last chunk + if (i == _num_splits - 1) { + _ub_comm->sms = UB_MAX_SM; + } + if (_ubuf.element_size() == 1) { + assert(_ubuf_scale_inv_initialized); + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + D.dtype(), fp8_type, + reducescatter2_userbuff_stridedoutput_fp8( + rs_output_ptr, _ubuf_scale_inv, _ub_reg, i * output_chunk_size, m_chunk, n, m, + _ub_comm, _stream_comm);); + } else { + reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, i * output_chunk_size, + m_chunk, n, m, _ub_comm, _stream_comm); + } + + rs_output_ptr += m_chunk * rs_output.element_size(); + input_a_chunk_ptr += input_a_chunk_size * B.element_size(); + output_buf_chunk_ptr += output_chunk_size * _ubuf.element_size(); + if (bias_chunk_ptr != nullptr) { + bias_chunk_ptr += bias_chunk_size * bias.element_size(); + } + } + } + + _ub_comm->sms = ori_sms; + for (size_t i = 0; i < _stream_compute.size(); i++) { + NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[i])); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_compute, 0)); + } + NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_comm)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0)); +} // CommOverlapBase::split_overlap_rs + +/*************************************************************************************************** + * Comm+GEMM Overlap P2P Base (Ring-Exchange) + **************************************************************************************************/ + +CommOverlapP2PBase::CommOverlapP2PBase(const std::vector &buffer_shape, DType buffer_dtype, + int myrank, int numranks, int mylocal, int numlocal, + int mynode, int numnodes, int tp_size, + ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, + CommOverlapType comm_type, int num_max_streams, + int comm_cga_size, int num_comm_sm, bool set_sm_margin, + bool use_ce, bool atomic_gemm, bool aggregate) + : CommOverlapCore(myrank, numranks, 
mylocal, numlocal, mynode, numnodes, tp_size, + allgather_handle, barrier_handle, tp_size, num_max_streams, comm_cga_size, + num_comm_sm, set_sm_margin, use_ce, atomic_gemm) { + _is_p2p = true; + _is_reduce_scatter = comm_type == CommOverlapType::RS; + _aggregate = aggregate; + + // Create workspace tensor with userbuffer + NVTE_CHECK(buffer_shape.size() == 2, "Userbuffer shape must be 2-dimensional!"); + size_t buffer_bytes = buffer_shape[0] * buffer_shape[1] * typeToSize(buffer_dtype); + int buffer_chunk_bytes = buffer_bytes / tp_size; + _num_ubuf_chunks = tp_size; + if (_is_reduce_scatter) { + // GEMM + RS overlap: Allocate `2 x tp_size - 1` buffers to hold recieved GEMM chunk + // outputs for reduction at the end of the pipelining. + buffer_bytes = buffer_bytes / tp_size * (tp_size * 2 - 1); + _num_ubuf_chunks = tp_size * 2 - 1; + } + + void *buffer_ptr; + _ub_reg = register_user_buffer_collective(&buffer_ptr, buffer_bytes, _ub_comm, true); + if (_rank == 0) printf("!!! [UBP2P] Register UBuf %d\n", _ub_reg); + _ubuf = TensorWrapper(buffer_ptr, {buffer_shape[0] / tp_size * _num_ubuf_chunks, buffer_shape[1]}, + buffer_dtype); + + // Create tensor chunks for easy management + char *ubuf_byte_ptr = reinterpret_cast(buffer_ptr); + for (int i = 0; i < _num_ubuf_chunks; i++) { + _ubufs.push_back(TensorWrapper(reinterpret_cast(ubuf_byte_ptr), + {buffer_shape[0] / tp_size, buffer_shape[1]}, buffer_dtype)); + ubuf_byte_ptr += buffer_chunk_bytes; + } + + _rank_round_tp = (_rank / _tp_size) * _tp_size; + _next_rank = (_tp_size + _rank + 1) % _tp_size + _rank_round_tp; + _prev_rank = (_tp_size + _rank + -1) % _tp_size + _rank_round_tp; + + _self_chunk_id = _tp_id; + if (_atomic_gemm && !_is_reduce_scatter) { + _use_multiatomic_ag = getenv("NVTE_AG_P2P_MULTI_ATOMIC"); + if (_use_multiatomic_ag) { + _use_ce = 0; + _ub_comm->push = 1; + if (_rank == 0) { + printf("!!userbuffers_sendrecv_multi_atomic_shuffle\n"); + } + } + _self_chunk_id = 0; + NVTE_CHECK_CUDA(cudaMemset(_counter.dptr(), 0, sizeof(int32_t))); + } + + NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&_stream_send, cudaStreamNonBlocking, -1)); + NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&_stream_recv, cudaStreamNonBlocking, -1)); + NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_send, 0)); + NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_recv, 0)); +} + +CommOverlapP2PBase::~CommOverlapP2PBase() { + cudaEventDestroy(_stop_recv); + cudaEventDestroy(_stop_send); + cudaStreamDestroy(_stream_recv); + cudaStreamDestroy(_stream_send); +} + +/* +** Split AllGather + AtomicGEMM using P2P communication +** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is needed to have AG +** outputs in each rank to be in the contiguous memory space after all ring exchange phases. +*/ +void CommOverlapP2PBase::atomic_gemm_overlap_ag(TensorWrapper &A, bool transa, TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, + TensorWrapper &B_copy, cudaStream_t stream_main) { + int ori_sms = _ub_comm->sms; + _ub_comm->use_ce = _use_ce; + _ub_comm->sms = _num_comm_sm; + _ub_comm->cga_size = _cga_size; + + // Get GEMM dimensions between TN and NN input layouts + const size_t m = (transa) ? 
A.size(0) : A.size(1); + const size_t n = _ubuf.size(0); + const size_t n_chunk = n / _tp_size; + assert(pre_gelu_out.numel() == 0); + + // Get communication and GEMM output chunk sizes + const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size(); + + // Create an GEMM output buffer with N+1 chunks in a contiguous memory + void *D_buffer_ptr; + int D_chunk_bytes = n_chunk * m * D.element_size(); + NVTE_CHECK_CUDA(cudaMallocAsync(&D_buffer_ptr, (_tp_size + 1) * D_chunk_bytes, stream_main)); + auto D_buffer = TensorWrapper(D_buffer_ptr, D.shape(), D.dtype(), D.amax(), D.scale(), nullptr); + + // Reset atomic counters + int *counter_ptr = reinterpret_cast(_counter.dptr()); + reset_counters(counter_ptr, _tp_size, true, stream_main); + + // Catch up the default torch stream + NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _start_compute, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0)); + + auto input_b = TensorWrapper(_ubuf.dptr(), B.shape(), B.dtype(), nullptr, nullptr, B.scale_inv()); + size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); + auto workspace_chunk = + TensorWrapper(workspace.dptr(), std::vector{workspace_size_chunk}, workspace.dtype()); + + for (int i = 0; i < _tp_size - 1; i++) { + // Set the userbuffer id. Buffer under send is the input for the current + // GEMM chunk The initial input chunk is stored _ubuf[rank]. This is to + // have the AG output in all ranks to be contiguous after the ring + // exchanges + int send_chunk_id = i; + int recv_chunk_id = i + 1; + int send_offset = comm_bytes * send_chunk_id; + int recv_offset = comm_bytes * recv_chunk_id; + + if (_use_multiatomic_ag) { + if (i == 0) { + _ub_comm->use_ce = 0; + userbuffers_sendrecv_multiatomic(_ub_reg, _ub_reg, comm_bytes, comm_bytes, comm_bytes, + _ub_comm, _next_rank, _prev_rank, _tp_size, counter_ptr, + true, _stream_recv); + } + } else { + userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, _next_rank, + _stream_recv); + userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, _prev_rank, + _stream_recv); + producer(counter_ptr, recv_chunk_id, _stream_recv); + } + if (i == 0) { + nvte_cublas_atomic_gemm(A.data(), input_b.data(), D_buffer.data(), bias.data(), + pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), + accumulate, use_split_accumulator, _math_sms, 0, _tp_size, false, + _counter.data(), stream_main); + } + } + + // Store the input activation for backprop + if (B_copy.numel() > 0) { + assert(B_copy.numel() == _ubufs[_self_chunk_id].numel()); + assert(B_copy.element_size() == _ubufs[_self_chunk_id].element_size()); + NVTE_CHECK_CUDA( + cudaMemcpyAsync(B_copy.dptr(), _ubufs[_self_chunk_id].dptr(), + _ubufs[_self_chunk_id].numel() * _ubufs[_self_chunk_id].element_size(), + cudaMemcpyDeviceToDevice, _stream_send)); + NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0)); + } + + // Copy the first GEMM output chunk to the end chunk position of D_buffer + char *src_ptr = reinterpret_cast(D_buffer.dptr()); + NVTE_CHECK_CUDA(cudaMemcpyAsync(src_ptr + (D.numel() * D.element_size()), src_ptr, D_chunk_bytes, + cudaMemcpyDeviceToDevice, stream_main)); + + // Return the last N rows of D_buffer + NVTE_CHECK_CUDA(cudaMemcpyAsync(D.dptr(), src_ptr + D_chunk_bytes, D.numel() * D.element_size(), + cudaMemcpyDeviceToDevice, stream_main)); + + // Clean 
up buffer allocation + NVTE_CHECK_CUDA(cudaFreeAsync(D_buffer_ptr, stream_main)); + + _ub_comm->sms = ori_sms; +} // CommOverlapP2PBase::atomic_gemm_overlap_ag + +/* +** Split AllGather + GEMM using P2P communication +** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is needed to have AG +** outputs in each rank to be in the contiguous memory space after all ring exchange phases. +*/ +void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, + bool grad, bool accumulate, bool use_split_accumulator, + TensorWrapper &B_copy, cudaStream_t stream_main) { + int ori_sms = _ub_comm->sms; + _ub_comm->use_ce = _use_ce; + _ub_comm->sms = _num_comm_sm; + _ub_comm->cga_size = _cga_size; + // Get GEMM dimensions between TN and NN input layouts + const size_t m = (transa) ? A.size(0) : A.size(1); + const size_t k = (transa) ? A.size(1) : A.size(0); + const size_t n_chunk = _ubufs[0].size(0); + + // Get communication and GEMM output chunk sizes + const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size(); + const bool do_gelu = pre_gelu_out.numel() > 0; + const int output_chunk_bytes = (n_chunk * m) * D.element_size(); + const int aux_chunk_bytes = do_gelu ? (n_chunk * m) * pre_gelu_out.element_size() : 0; + + // Get output and workspace data pointers + char *output_ptr = reinterpret_cast(D.dptr()); + char *pre_gelu_out_ptr = reinterpret_cast(pre_gelu_out.dptr()); + char *workspace_ptr = reinterpret_cast(workspace.dptr()); + size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); + + NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _start_compute, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0)); + for (size_t i = 0; i < _stream_compute.size(); i++) { + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[i], _start_compute, 0)); + } + if (_aggregate) { + const int num_steps = _tp_size / 2; + char *input_b_ptr = reinterpret_cast(_ubuf.dptr()); + + // Initial 1X input chunk exchange between neighboring peers + int send_chunk_id = _tp_id; + int recv_chunk_id = (_tp_id % 2 == 0) ? _tp_id + 1 : _tp_id - 1; + int send_offset = comm_bytes * send_chunk_id; + int recv_offset = comm_bytes * recv_chunk_id; + int peer_rank = (_tp_id % 2 == 0) ? _next_rank : _prev_rank; + userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes, _ub_comm, peer_rank, + _stream_send); + userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, peer_rank, + _stream_recv); + NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _stop_recv, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[0], _stop_recv, 0)); + + int local_rank_round2 = (_tp_id % 2 == 0) ? 
_tp_id : _tp_id - 1; + const int next_rank = (_tp_size + _tp_id + 2) % _tp_size + _rank_round_tp; + const int prev_rank = (_tp_size + _tp_id - 2) % _tp_size + _rank_round_tp; + + // Ring exchange of 2X inputs chunks + for (int i = 0; i < num_steps; i++) { + send_chunk_id = (_tp_size + local_rank_round2 - i * 2) % _tp_size; + recv_chunk_id = (_tp_size + local_rank_round2 - i * 2 - 2) % _tp_size; + send_offset = comm_bytes * send_chunk_id; + recv_offset = comm_bytes * recv_chunk_id; + + // GEMM + char *input_b_chunk_ptr = input_b_ptr + send_offset; + auto input_b_chunk = + TensorWrapper(reinterpret_cast(input_b_chunk_ptr), {n_chunk * 2, k}, B.dtype(), + nullptr, nullptr, B.scale_inv()); + + char *output_chunk_ptr = output_ptr + (send_chunk_id * output_chunk_bytes); + auto output_chunk = TensorWrapper(reinterpret_cast(output_chunk_ptr), + {n_chunk * 2, m}, D.dtype(), D.amax(), D.scale(), nullptr); + + char *aux_chunk_ptr = + (do_gelu) ? pre_gelu_out_ptr + (send_chunk_id * aux_chunk_bytes) : nullptr; + auto aux_chunk_shape = + (do_gelu) ? std::vector{n_chunk * 2, m} : std::vector{0}; + auto aux_chunk = TensorWrapper(reinterpret_cast(aux_chunk_ptr), aux_chunk_shape, + pre_gelu_out.dtype()); + + char *workspace_chunk_ptr = + workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk; + auto workspace_chunk = + TensorWrapper(reinterpret_cast(workspace_chunk_ptr), + std::vector{workspace_size_chunk}, workspace.dtype()); + + nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(), + aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate, + use_split_accumulator, _math_sms, + _stream_compute[i % _stream_compute.size()]); + + if (i < num_steps - 1) { + // P2P communication + userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes * 2, _ub_comm, + next_rank, _stream_send); + userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes * 2, _ub_comm, + prev_rank, _stream_recv); + NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _stop_recv, 0)); + NVTE_CHECK_CUDA( + cudaStreamWaitEvent(_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0)); + } else if (B_copy.numel() > 0) { + assert(B_copy.numel() == _ubufs[_tp_id].numel()); + assert(B_copy.element_size() == _ubufs[_tp_id].element_size()); + NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_tp_id].dptr(), + _ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(), + cudaMemcpyDeviceToDevice, _stream_send)); + } + } + } else { + for (int i = 0; i < _tp_size; i++) { + // Set the userbuffer id. Buffer under send is the input for the current + // GEMM chunk The initial input chunk is stored _ubuf[rank]. This is to + // have the AG output in all ranks to be contiguous after the ring + // exchanges + int send_chunk_id = (_tp_size + _tp_id - i) % _tp_size; + int recv_chunk_id = (_tp_size + _tp_id - i - 1) % _tp_size; + int send_offset = comm_bytes * send_chunk_id; + int recv_offset = comm_bytes * recv_chunk_id; + + // GEMM + auto input_b_chunk = TensorWrapper(_ubufs[send_chunk_id].dptr(), {n_chunk, k}, B.dtype(), + nullptr, nullptr, B.scale_inv()); + + char *output_chunk_ptr = output_ptr + (send_chunk_id * output_chunk_bytes); + auto output_chunk = TensorWrapper(reinterpret_cast(output_chunk_ptr), {n_chunk, m}, + D.dtype(), D.amax(), D.scale(), nullptr); + + char *aux_chunk_ptr = + (do_gelu) ? pre_gelu_out_ptr + (send_chunk_id * aux_chunk_bytes) : nullptr; + auto aux_chunk_shape = (do_gelu) ? 
std::vector{n_chunk, m} : std::vector{0}; + auto aux_chunk = TensorWrapper(reinterpret_cast(aux_chunk_ptr), aux_chunk_shape, + pre_gelu_out.dtype()); + + char *workspace_chunk_ptr = + workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk; + auto workspace_chunk = + TensorWrapper(reinterpret_cast(workspace_chunk_ptr), + std::vector{workspace_size_chunk}, workspace.dtype()); + + nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(), + aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate, + use_split_accumulator, _math_sms, + _stream_compute[i % _stream_compute.size()]); + + if (i < _tp_size - 1) { + // P2P communication + userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes, _ub_comm, + _next_rank, _stream_send); + userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, + _prev_rank, _stream_recv); + NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _stop_recv, 0)); + NVTE_CHECK_CUDA( + cudaStreamWaitEvent(_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0)); + } else if (B_copy.numel() > 0) { + assert(B_copy.numel() == _ubufs[_tp_id].numel()); + assert(B_copy.element_size() == _ubufs[_tp_id].element_size()); + NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_tp_id].dptr(), + _ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(), + cudaMemcpyDeviceToDevice, _stream_send)); + } + } + } + + _ub_comm->sms = ori_sms; + for (size_t i = 0; i < _stream_compute.size(); i++) { + NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[i])); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_compute, 0)); + } + NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0)); + NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0)); +} // CommOverlapP2PBase::split_overlap_ag + +/* +** Split ReduceScatter + GEMM using P2P communication +*/ +void CommOverlapP2PBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, + TensorWrapper &rs_output, + cudaStream_t stream_main) { + int ori_sms = _ub_comm->sms; + _ub_comm->use_ce = _use_ce; + _ub_comm->sms = _num_comm_sm; + _ub_comm->cga_size = _cga_size; + + // Get communication and GEMM input chunk sizes + const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size(); + + // Reset counters + int *counter_ptr = reinterpret_cast(_counter.dptr()); + reset_counters(counter_ptr, _tp_size, false, stream_main); + + // Catch up the main stream + NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0)); + + // Atomic GEMM + // Process GEMM chunks in the order that AG+GEMM places the output chunks. 
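[Editor's note] The ring-exchange loops in `split_overlap_ag` above interleave one GEMM chunk per step with a send of the chunk just used and a receive of the next one, so that the all-gathered input ends up contiguous in `_ubuf`. A minimal Python sketch of only the index bookkeeping in the non-aggregated path (no communication is performed; helper name `ring_schedule` is illustrative):

```python
# Chunk/rank schedule for the non-aggregated AG+GEMM ring exchange.
def ring_schedule(tp_size: int, tp_id: int, rank_round_tp: int):
    next_rank = (tp_size + tp_id + 1) % tp_size + rank_round_tp
    prev_rank = (tp_size + tp_id - 1) % tp_size + rank_round_tp
    steps = []
    for i in range(tp_size):
        send_chunk_id = (tp_size + tp_id - i) % tp_size      # chunk consumed by this GEMM step
        recv_chunk_id = (tp_size + tp_id - i - 1) % tp_size   # chunk arriving for the next step
        steps.append((i, send_chunk_id, recv_chunk_id, next_rank, prev_rank))
    return steps

for step in ring_schedule(tp_size=4, tp_id=1, rank_round_tp=0):
    print(step)
```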
+ auto output_d = TensorWrapper(_ubuf.dptr(), D.shape(), D.dtype(), D.amax(), D.scale(), nullptr); + size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); + auto workspace_chunk = + TensorWrapper(workspace.data(), std::vector{workspace_size_chunk}, workspace.dtype()); + nvte_cublas_atomic_gemm(A.data(), B.data(), output_d.data(), bias.data(), pre_gelu_out.data(), + transa, transb, grad, workspace_chunk.data(), accumulate, + use_split_accumulator, _math_sms, 0, _tp_size, true, _counter.data(), + stream_main); + + // P2P communication chunk + for (int i = 1; i < _tp_size; i++) { + int send_chunk_id = i - 1; + int recv_chunk_id = send_chunk_id + _tp_size; + int send_offset = comm_bytes * send_chunk_id; + int recv_offset = comm_bytes * recv_chunk_id; + int send_rank = (_tp_size + _tp_id - i) % _tp_size + _rank_round_tp; + int recv_rank = (_tp_id + i) % _tp_size + _rank_round_tp; + + consumer(counter_ptr, send_chunk_id, _stream_recv); + userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, send_rank, + _stream_recv); + userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, recv_rank, + _stream_recv); + } + NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0)); + + // Reduce GEMM output chunks + char *reduce_buf_ptr = reinterpret_cast(_ubufs[_tp_size - 1].dptr()); + char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); + if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) { + assert(_ubuf_scale_inv_initialized); + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + D.dtype(), fp8_type, + reduce_fp8_in_bf16_out(reduce_buf_ptr, rs_output_ptr, _ubuf_scale_inv, _tp_size, + _ubufs[0].numel(), stream_main);); + } else { + reduce_bf16(reduce_buf_ptr, rs_output_ptr, _tp_size, _ubufs[0].numel(), stream_main); + } + _ub_comm->sms = ori_sms; +} + +/* +** Split ReduceScatter + GEMM using P2P communication +*/ +void CommOverlapP2PBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, + bool grad, bool accumulate, bool use_split_accumulator, + TensorWrapper &rs_output, cudaStream_t stream_main) { + int ori_sms = _ub_comm->sms; + _ub_comm->use_ce = _use_ce; + _ub_comm->sms = _num_comm_sm; + _ub_comm->cga_size = _cga_size; + size_t k = A.size(1); + size_t n = B.size(0); + + // Get communication and GEMM input chunk sizes + size_t n_chunk = n / _tp_size; + const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size(); + const int input_b_chunk_bytes = n_chunk * k * B.element_size(); + + // Get input and workspace data pointers + char *input_b_ptr = reinterpret_cast(B.dptr()); + char *workspace_ptr = reinterpret_cast(workspace.dptr()); + size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); + + // Catch up the main stream + NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _start_compute, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0)); + for (size_t i = 0; i < _stream_compute.size(); i++) { + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[i], _start_compute, 0)); + } + + // GEMM and send/recv chunks + for (int i = 0; i < _tp_size; i++) { + // GEMM chunk + int input_b_chunk_id = (_tp_id + i + 1) % _tp_size; + char *input_b_chunk_ptr = input_b_ptr + (input_b_chunk_id * input_b_chunk_bytes); + + auto input_b_chunk 
= TensorWrapper(reinterpret_cast(input_b_chunk_ptr), {n_chunk, k}, + B.dtype(), nullptr, nullptr, B.scale_inv()); + + auto output_chunk = + TensorWrapper(_ubufs[i].dptr(), _ubufs[i].shape(), D.dtype(), D.amax(), D.scale(), nullptr); + + char *workspace_chunk_ptr = workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk; + auto workspace_chunk = + TensorWrapper(reinterpret_cast(workspace_chunk_ptr), + std::vector{workspace_size_chunk}, workspace.dtype()); + + nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(), + pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate, + use_split_accumulator, _math_sms, _stream_compute[i % _stream_compute.size()]); + + if (i > 0) { + // P2P communication chunk + int send_offset = comm_bytes * (i - 1); + int recv_offset = comm_bytes * (i - 1 + _tp_size); + int send_rank = (_tp_id + i) % _tp_size + _rank_round_tp; + int recv_rank = (_tp_size + _tp_id - i) % _tp_size + _rank_round_tp; + NVTE_CHECK_CUDA( + cudaEventRecord(_start_comm, _stream_compute[(i - 1) % _stream_compute.size()])); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _start_comm, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_comm, 0)); + userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, send_rank, + _stream_send); + userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, recv_rank, + _stream_recv); + } + } + + for (size_t i = 0; i < _stream_compute.size(); i++) { + NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[i])); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_compute, 0)); + } + NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0)); + NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0)); + + // Reduce GEMM output chunks + char *reduce_buf_ptr = reinterpret_cast(_ubufs[_tp_size - 1].dptr()); + char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); + if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) { + assert(_ubuf_scale_inv_initialized); + char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + D.dtype(), fp8_type, + reduce_fp8_in_bf16_out(reduce_buf_ptr, rs_output_ptr, _ubuf_scale_inv, _tp_size, + _ubufs[0].numel(), stream_main);); + } else { + reduce_bf16(reduce_buf_ptr, rs_output_ptr, _tp_size, _ubufs[0].numel(), stream_main); + } + + _ub_comm->sms = ori_sms; +} + +} // namespace transformer_engine diff --git a/transformer_engine/pytorch/csrc/userbuffers/ipcsocket.cc b/transformer_engine/common/comm_gemm_overlap/userbuffers/ipcsocket.cc similarity index 100% rename from transformer_engine/pytorch/csrc/userbuffers/ipcsocket.cc rename to transformer_engine/common/comm_gemm_overlap/userbuffers/ipcsocket.cc diff --git a/transformer_engine/pytorch/csrc/userbuffers/ipcsocket.h b/transformer_engine/common/comm_gemm_overlap/userbuffers/ipcsocket.h similarity index 100% rename from transformer_engine/pytorch/csrc/userbuffers/ipcsocket.h rename to transformer_engine/common/comm_gemm_overlap/userbuffers/ipcsocket.h diff --git a/transformer_engine/pytorch/csrc/userbuffers/userbuffers-host.cpp b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp similarity index 92% rename from transformer_engine/pytorch/csrc/userbuffers/userbuffers-host.cpp rename to 
transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp index e2628f6a31..6f3eef3d28 100644 --- a/transformer_engine/pytorch/csrc/userbuffers/userbuffers-host.cpp +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp @@ -20,7 +20,9 @@ #include #include "common/util/cuda_driver.h" +#include "common/util/cuda_runtime.h" #include "common/util/logging.h" +#include "common/util/system.h" #include "ipcsocket.h" #include "userbuffers.h" @@ -44,31 +46,19 @@ static MPI_Comm EXT_COMM_INTER; } while (false) void ub_mpi_allgather(void *globaldata, size_t globalbytes, void *localdata, size_t localbytes, - ExtComm group) { - // UB_MPI_CHECK(MPI_Allgather(localdata, localbytes, MPI_BYTE, - // globaldata, globalbytes, MPI_BYTE, - // static_cast(group))); - MPI_Comm comm = static_cast(group); + ExtComm comm) { int numranks; UB_MPI_CHECK(MPI_Comm_size(comm, &numranks)); assert(globalbytes == numranks * localbytes); - - int myrank; - UB_MPI_CHECK(MPI_Comm_rank(comm, &myrank)); - char *globaltarget = reinterpret_cast(globaldata) + (myrank * localbytes); - memcpy(globaltarget, localdata, localbytes); - - for (int n = 0; n < numranks; n++) { - globaltarget = reinterpret_cast(globaldata) + (n * localbytes); - UB_MPI_CHECK(MPI_Bcast(globaltarget, localbytes, MPI_BYTE, n, comm)); - } + UB_MPI_CHECK( + MPI_Allgather(localdata, localbytes, MPI_BYTE, globaldata, localbytes, MPI_BYTE, comm)); } -void ub_mpi_barrier(ExtComm group) { UB_MPI_CHECK(MPI_Barrier(static_cast(group))); } +void ub_mpi_barrier(ExtComm comm) { UB_MPI_CHECK(MPI_Barrier(comm)); } #else -static char EXT_COMM_WORLD[] = "world"; -static char EXT_COMM_INTRA[] = "intra"; -static char EXT_COMM_INTER[] = "inter"; +#define EXT_COMM_WORLD "world" +#define EXT_COMM_INTRA "intra" +#define EXT_COMM_INTER "inter" #endif #define MULTICAST_GB_TOTAL 512 @@ -106,11 +96,10 @@ int pipe_rank(communicator *comm, int step) { return newnode * numlocal + newlocal; } -int create_communicator_grouped2( - communicator **comm, int myrank, int numranks, int mylocal, int numlocal, int mynode, - int numnodes, std::function ext_allgather, - std::function ext_barrier, int pipegpus, int pipenodes, int tensorgpus, - int tensornodes) { +int create_communicator_grouped2(communicator **comm, int myrank, int numranks, int mylocal, + int numlocal, int mynode, int numnodes, + ExtAllgatherOp ext_allgather, ExtBarrierOp ext_barrier, + int pipegpus, int pipenodes, int tensorgpus, int tensornodes) { *comm = new communicator(); (*comm)->comm_world = EXT_COMM_WORLD; @@ -214,8 +203,11 @@ int create_communicator_grouped2( (*comm)->asyncblocks = 16; #define NBUF 2 - if ((*comm)->sm_arch >= 9 && (*comm)->ar2_nvsize > 1 && - !getenv("UB_SKIPMC")) { // multicast init only for TP ops (____2 operations) + +#if CUDART_VERSION >= 12010 + if (!transformer_engine::getenv("UB_SKIPMC") && + transformer_engine::cuda::supports_multicast() && (*comm)->ar2_nvsize > 1) { + // multicast init only for TP ops (____2 operations) size_t mc_maxsize = MULTICAST_GB_TOTAL * (1ull << 30); (*comm)->mc_offset = 0; (*comm)->use_mc = 1; @@ -291,20 +283,20 @@ int create_communicator_grouped2( (*comm)->_barrier((*comm)->comm_world); if (!(*comm)->myrank) printf("MC initialized succesfully, window size = %ld\n", mc_maxsize); } else { +#endif if (!(*comm)->myrank) printf("MC NOT initialized and used\n"); (*comm)->mc_maxsize = 0; (*comm)->mc_offset = 0; (*comm)->use_mc = 0; +#if CUDART_VERSION >= 12010 } +#endif #define LOCALSIZE 4 * (NVTE_REG0_OFFSET(*comm) + NVTE_REG0_FLAGS + 
NVTE_REG0_COMMBUFFER * NBUF) // peer pointers + op flags + comm buffer - NVTE_CHECK_CUDA( - cudaMalloc(&(*comm)->gpu_ptrs, LOCALSIZE)); // flags and pointers, no block data yet - NVTE_CHECK_CUDA(cudaMemset((*comm)->gpu_ptrs, 0, LOCALSIZE)); NVTE_CHECK_CUDA(cudaDeviceSynchronize()); - register_user_buffer_collective(&((*comm)->gpu_ptrs), LOCALSIZE, *comm, false); + register_user_buffer_collective(&((*comm)->gpu_ptrs), LOCALSIZE, *comm, true); NVTE_CHECK_CUDA(cudaMalloc(&(*comm)->send_id, (*comm)->nranks * sizeof(int))); NVTE_CHECK_CUDA(cudaMalloc(&(*comm)->recv_id, NVTE_MAX_REGIONS * (*comm)->nranks * sizeof(int))); NVTE_CHECK_CUDA(cudaMemset((*comm)->send_id, 0, (*comm)->nranks * sizeof(int))); @@ -346,18 +338,17 @@ int create_communicator_grouped2( return 0; } -int create_communicator_grouped( - communicator **comm, int myrank, int numranks, int mylocal, int numlocal, int mynode, - int numnodes, std::function ext_allgather, - std::function ext_barrier, int pipegpus, int pipenodes) { +int create_communicator_grouped(communicator **comm, int myrank, int numranks, int mylocal, + int numlocal, int mynode, int numnodes, + ExtAllgatherOp ext_allgather, ExtBarrierOp ext_barrier, + int pipegpus, int pipenodes) { return create_communicator_grouped2(comm, myrank, numranks, mylocal, numlocal, mynode, numnodes, ext_allgather, ext_barrier, pipegpus, pipenodes, 1, 1); } int create_communicator(communicator **comm, int myrank, int numranks, int mylocal, int numlocal, - int mynode, int numnodes, - std::function ext_allgather, - std::function ext_barrier) { + int mynode, int numnodes, ExtAllgatherOp ext_allgather, + ExtBarrierOp ext_barrier) { return create_communicator_grouped2(comm, myrank, numranks, mylocal, numlocal, mynode, numnodes, ext_allgather, ext_barrier, 1, 1, 1, 1); } @@ -428,7 +419,7 @@ int create_communicator_mpi(communicator **comm) { void destroy_communicator(communicator *comm) { for (int hndl = 0; hndl < comm->free_region; hndl++) { - if (hndl > 0 && comm->use_mc && comm->mem_dealloc[hndl]) { + if (comm->use_mc && comm->mem_dealloc[hndl]) { for (int rank = 0; rank < comm->nvsize; rank++) { if (rank == comm->nvrank) { NVTE_CALL_CHECK_CUDA_DRIVER(cuMemRelease, comm->uchandles[hndl][rank]); @@ -479,6 +470,7 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator * comm->memflags[hndl] = 0; comm->mem_dealloc[hndl] = alloc; +#if CUDART_VERSION >= 12010 if (comm->use_mc && alloc) { int nranks = comm->nvsize; // total GPUs in NVLINK domain int myrank = comm->nvrank; @@ -594,6 +586,7 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator * } } else { +#endif if (alloc) { NVTE_CHECK_CUDA(cudaMalloc(gpubuff, bytes)); NVTE_CHECK_CUDA(cudaMemset(*gpubuff, 0, bytes)); @@ -624,7 +617,9 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator * NVTE_CHECK_CUDA(cudaDeviceSynchronize()); free(tmp); +#if CUDART_VERSION >= 12010 } +#endif comm->mem_size[hndl] = aligned_size; comm->mem_ptr[hndl] = *gpubuff; diff --git a/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu similarity index 98% rename from transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu rename to transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu index 0cd2a0253b..26843d8107 100644 --- a/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu @@ -392,14 +392,14 @@ __global__ 
void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_ if (threadIdx.x == 0 && lastSM) *reduceidptr = reduce_id; } // fp16 reduce-scatter kernel (out of place) -#if __CUDA_ARCH__ >= 900 +#if __CUDA_ARCH__ >= 900 && CUDART_VERSION >= 12010 // All MC kernels here template __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_mc(const int op, const int flagoffset, const int firstrank, const int myrank, const int gpustep, const int lineoffset, const int numlines, void **commbuff, const int handleridx, - float4 *mc_ptr) { + float4 *mc_ptr, const uint64_t ub_timeout) { int *flagptr, physgpu, targetgpu, *myptr; int *reduceidptr, reduce_id; @@ -417,7 +417,7 @@ __global__ void __launch_bounds__(MAX_THREADS) volatile int *flag = (volatile int *)&(myptr[targetgpu]); clock_t s = clock64(); while (CHECK_IDS(*flag, reduce_id)) { - if (clock64() - s > TIMEOUT) { + if (clock64() - s > ub_timeout) { UB_PRINT("Reduce-scatter: SM %d [%d]: expecting %d got %d", blockIdx.x, threadIdx.x, reduce_id, *flag); break; @@ -484,7 +484,7 @@ __global__ void __launch_bounds__(MAX_THREADS) volatile int *flag = (volatile int *)&myptr[targetgpu]; clock_t s = clock64(); while (CHECK_IDS(*flag, reduce_id)) { - if (clock64() - s > 2ull * TIMEOUT) { + if (clock64() - s > 2ull * ub_timeout) { UB_PRINT("Allgather: SM %d [%d]: expecting %d got %d", blockIdx.x, threadIdx.x, reduce_id, *flag); break; @@ -741,7 +741,7 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_mc(const int op, const int flagoffset, const int firstrank, const int myrank, const int gpustep, const int lineoffset, const int numlines, void **commbuff, const int handleridx, - float4 *mc_ptr) {} + float4 *mc_ptr, const uint64_t ub_timeout) {} template __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_mc_rs_oop(const int op, const int flagoffset, @@ -2496,6 +2496,18 @@ static __global__ void consumer_batch_kernel(void *atomic_ptr, int first_chunk_i } } +// reset counters kernel +static __global__ void reset_counters_kernel(void *atomic_ptr, int num_chunks, bool allgather) { + if (blockIdx.x == 0 && threadIdx.x == 0) { +#pragma unroll + for (int i = 0; i < num_chunks; i++) { + ((unsigned int *)atomic_ptr)[i] = 1; + ((unsigned int *)atomic_ptr)[i + num_chunks] = 0; + } + if (allgather) ((unsigned int *)atomic_ptr)[0] = 0; + } +} + void producer(void *atomic_ptr, int chunk_i, cudaStream_t stream) { dim3 block(1); dim3 grid(1); @@ -2514,6 +2526,12 @@ void consumer_batch(void *atomic_ptr, int first_chunk_i, int num_chunks, cudaStr consumer_batch_kernel<<>>(atomic_ptr, first_chunk_i, num_chunks); } +void reset_counters(void *atomic_ptr, int num_chunks, bool allgather, cudaStream_t stream) { + dim3 block(1); + dim3 grid(1); + reset_counters_kernel<<>>(atomic_ptr, num_chunks, allgather); +} + template __global__ void __launch_bounds__(MAX_THREADS / 4) reduce_fp8_in_bf16_out_cuda(void *inputs, void *output, const float *scale, @@ -2546,3 +2564,24 @@ template void reduce_fp8_in_bf16_out<__nv_fp8_e4m3>(void *inputs, void *output, template void reduce_fp8_in_bf16_out<__nv_fp8_e5m2>(void *inputs, void *output, float *scale, int num_inputs, int input_size, cudaStream_t stream); + +__global__ void __launch_bounds__(MAX_THREADS / 4) + reduce_bf16_cuda(void *inputs, void *output, const int num_inputs, const int input_size) { + const size_t tid = threadIdx.x + blockDim.x * blockIdx.x; + half *inputs_half = reinterpret_cast(inputs); + float accum_buf = static_cast(inputs_half[tid]); 
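// Note added for readability (not in the original patch): `inputs` holds
// num_inputs contiguous slices of input_size elements each (one partial GEMM
// result per tensor-parallel rank). Each thread owns output element `tid`:
// the first slice was loaded into accum_buf above, the remaining slices are
// accumulated in fp32 in the loop below, and a single reduced element is
// written out, so the host launcher uses one thread per output element.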
+#pragma unroll + for (int i = 1; i < num_inputs; i++) { + accum_buf += static_cast(inputs_half[tid + input_size * i]); + } + half *output_half = reinterpret_cast(output); + output_half[tid] = (half)accum_buf; +} + +void reduce_bf16(void *inputs, void *output, int num_inputs, int input_size, cudaStream_t stream) { + size_t num_threads = MAX_THREADS / 4; + size_t num_blocks = (input_size + num_threads - 1) / num_threads; + dim3 block(num_threads); + dim3 grid(num_blocks); + reduce_bf16_cuda<<>>(inputs, output, num_inputs, input_size); +} diff --git a/transformer_engine/pytorch/csrc/userbuffers/userbuffers.h b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h similarity index 90% rename from transformer_engine/pytorch/csrc/userbuffers/userbuffers.h rename to transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h index 371932f446..57e68afce0 100644 --- a/transformer_engine/pytorch/csrc/userbuffers/userbuffers.h +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h @@ -19,11 +19,14 @@ #ifdef NVTE_UB_WITH_MPI #include -typedef MPI_Comm ExtComm; +#define ExtComm MPI_Comm #else -typedef char *ExtComm; +#define ExtComm const char * #endif +using ExtAllgatherOp = std::function; +using ExtBarrierOp = std::function; + #define NVTE_MAX_REGIONS 16 #define NVTE_MAX_SMS 32 #define NVTE_MAX_OPS 32 @@ -142,12 +145,12 @@ struct communicator { volatile int tail; // Abstract communication callbacks to support external bootstrapping (e.g. DL frameworks) - std::function _allgather; - std::function _barrier; + ExtAllgatherOp _allgather; + ExtBarrierOp _barrier; - ExtComm comm_world, - comm_inter, // reduction group communicator (subset of the nodes) along GPU rail - comm_intra; // full intranode (all ndev GPUS) + ExtComm comm_world; + ExtComm comm_inter; // reduction group communicator (subset of the nodes) along GPU rail + ExtComm comm_intra; // full intranode (all ndev GPUS) #ifdef NVTE_UB_WITH_MPI MPI_Request mpihndl[NVTE_MAX_SHARP]; #endif @@ -161,23 +164,22 @@ typedef struct communicator communicator; void producer(void *atomic_ptr, int chunk_i, cudaStream_t stream); void consumer(void *atomic_ptr, int chunk_i, cudaStream_t stream); void consumer_batch(void *atomic_ptr, int first_chunk_i, int num_chunks, cudaStream_t stream); +void reset_counters(void *atomic_ptr, int num_chunks, bool allgather, cudaStream_t stream); /* creates communicator, allocates all internal buffers if necessary */ -int create_communicator_grouped2( - communicator **comm, int myrank, int numranks, int mylocal, int numlocal, int mynode, - int numnodes, std::function ext_allgather, - std::function ext_barrier, int pipegpus, int pipenodes, int tensorgpus, - int tensornodes); +int create_communicator_grouped2(communicator **comm, int myrank, int numranks, int mylocal, + int numlocal, int mynode, int numnodes, + ExtAllgatherOp ext_allgather, ExtBarrierOp ext_barrier, + int pipegpus, int pipenodes, int tensorgpus, int tensornodes); -int create_communicator_grouped( - communicator **comm, int myrank, int numranks, int mylocal, int numlocal, int mynode, - int numnodes, std::function ext_allgather, - std::function ext_barrier, int pipegpus, int pipenodes); +int create_communicator_grouped(communicator **comm, int myrank, int numranks, int mylocal, + int numlocal, int mynode, int numnodes, + ExtAllgatherOp ext_allgather, ExtBarrierOp ext_barrier, + int pipegpus, int pipenodes); int create_communicator(communicator **comm, int myrank, int numranks, int mylocal, int numlocal, - int mynode, 
int numnodes, - std::function ext_allgather, - std::function ext_barrier); + int mynode, int numnodes, ExtAllgatherOp ext_allgather, + ExtBarrierOp ext_barrier); int create_communicator_grouped2_mpi(communicator **comm, int pipegpus, int pipenodes, int tensorgpus, int tensornodes); @@ -314,4 +316,6 @@ template void reduce_fp8_in_bf16_out(void *input, void *output, float *scale, int num_inputs, int input_size, cudaStream_t stream); +void reduce_bf16(void *input, void *output, int num_inputs, int input_size, cudaStream_t stream); + #endif // TRANSFORMER_ENGINE_USERBUFFERS_H_ diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 7e72e1b031..8830c8875d 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -56,6 +56,7 @@ constexpr T DIVUP(const T &x, const T &y) { using byte = uint8_t; using int32 = int32_t; +using int64 = int64_t; using fp32 = float; using fp16 = half; using bf16 = nv_bfloat16; @@ -73,6 +74,7 @@ constexpr inline const char *type_name() noexcept; } TRANSFORMER_ENGINE_TYPE_NAME(uint8_t) TRANSFORMER_ENGINE_TYPE_NAME(int32_t) +TRANSFORMER_ENGINE_TYPE_NAME(int64_t) TRANSFORMER_ENGINE_TYPE_NAME(float) TRANSFORMER_ENGINE_TYPE_NAME(half) TRANSFORMER_ENGINE_TYPE_NAME(nv_bfloat16) @@ -84,7 +86,7 @@ TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e5m2) template struct TypeInfo { - using types = std::tuple; + using types = std::tuple; template struct Helper { @@ -121,7 +123,11 @@ struct TypeInfo { { __VA_ARGS__ } \ } break; \ case DType::kInt32: { \ - using type = float; \ + using type = int32_t; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kInt64: { \ + using type = int64_t; \ { __VA_ARGS__ } \ } break; \ case DType::kFloat32: { \ @@ -246,6 +252,14 @@ inline int log2_ceil(int value) { return log2_value; } +template +inline size_t alignTo(size_t x) { + size_t r = x % B; + if (r == 0) return x; + + return x + B - r; +} + template struct is_fp8 : std::false_type {}; diff --git a/transformer_engine/common/cudnn_utils.cpp b/transformer_engine/common/cudnn_utils.cpp new file mode 100644 index 0000000000..35e2d11799 --- /dev/null +++ b/transformer_engine/common/cudnn_utils.cpp @@ -0,0 +1,73 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "cudnn_utils.h" + +#include "./util/logging.h" +#include "transformer_engine/cudnn.h" + +namespace transformer_engine { + +// get cuDNN data type +cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t) { + using namespace transformer_engine; + switch (t) { + case DType::kInt32: + return CUDNN_DATA_INT32; + case DType::kInt64: + return CUDNN_DATA_INT64; + case DType::kFloat16: + return CUDNN_DATA_HALF; + case DType::kFloat32: + return CUDNN_DATA_FLOAT; + case DType::kBFloat16: + return CUDNN_DATA_BFLOAT16; + case DType::kFloat8E4M3: + return CUDNN_DATA_FP8_E4M3; + case DType::kFloat8E5M2: + return CUDNN_DATA_FP8_E5M2; + default: + NVTE_ERROR("Invalid cuDNN data type. 
\n"); + } +} + +// get cuDNN data type +cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t) { + using namespace transformer_engine; + switch (t) { + case DType::kInt32: + return cudnn_frontend::DataType_t::INT32; + case DType::kInt64: + return cudnn_frontend::DataType_t::INT64; + case DType::kFloat16: + return cudnn_frontend::DataType_t::HALF; + case DType::kFloat32: + return cudnn_frontend::DataType_t::FLOAT; + case DType::kBFloat16: + return cudnn_frontend::DataType_t::BFLOAT16; + case DType::kFloat8E4M3: + return cudnn_frontend::DataType_t::FP8_E4M3; + case DType::kFloat8E5M2: + return cudnn_frontend::DataType_t::FP8_E5M2; + default: + NVTE_ERROR("Invalid cuDNN data type. \n"); + } +} + +void nvte_cudnn_handle_init() { + auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); +} + +} // namespace transformer_engine + +namespace cudnn_frontend { + +// This is needed to define the symbol `cudnn_dlhandle` +// When using the flag NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING +// to enable dynamic loading. +void *cudnn_dlhandle = nullptr; + +} // namespace cudnn_frontend diff --git a/transformer_engine/common/cudnn_utils.h b/transformer_engine/common/cudnn_utils.h new file mode 100644 index 0000000000..d2827b637a --- /dev/null +++ b/transformer_engine/common/cudnn_utils.h @@ -0,0 +1,46 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_CUDNN_UTILS_H_ +#define TRANSFORMER_ENGINE_CUDNN_UTILS_H_ + +#include +#include +#include + +#include +#include + +#include "transformer_engine/transformer_engine.h" + +namespace transformer_engine { + +cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t); + +cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t); + +class cudnnExecutionPlanManager { + public: + static cudnnExecutionPlanManager &Instance() { + static thread_local cudnnExecutionPlanManager instance; + return instance; + } + + cudnnHandle_t GetCudnnHandle() { + static thread_local std::once_flag flag; + std::call_once(flag, [&] { cudnnCreate(&handle_); }); + return handle_; + } + + ~cudnnExecutionPlanManager() {} + + private: + cudnnHandle_t handle_ = nullptr; +}; + +} // namespace transformer_engine + +#endif diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 70f1fa409f..9cde765401 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -7,6 +7,7 @@ #include "transformer_engine/fused_attn.h" #include "../common.h" +#include "../cudnn_utils.h" #include "../util/cuda_runtime.h" #include "../util/system.h" #include "fused_attn_f16_arbitrary_seqlen.h" @@ -80,7 +81,18 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( const int sm_arch_ = cuda::sm_arch(device_id); NVTE_CHECK(q_dtype == kv_dtype, "Q and KV must have the same data type."); NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); + NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); auto cudnn_runtime_version = cudnnGetVersion(); + + // For ragged offsets we only support 32-bit prior to cuDNN 9.5 + // Only used when THD format is requested. 
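An aside for readers of this hunk, not part of the patch: `fused_attn::get_ragged_offset_dtype` is defined elsewhere in the tree, so its selection criterion is not visible in this diff. A plausible, simplified sketch, assuming the choice is just whether the largest element offset of a THD-packed tensor still fits in a signed 32-bit integer (the names and signature below are hypothetical), is:

#include <algorithm>
#include <cstdint>

enum class OffsetDType { kInt32, kInt64 };

// Sketch only. Ragged offsets index into tensors packed as
// [total_tokens, heads, head_dim], so the largest offset that can appear is
// roughly max_tokens * heads * head_dim for the Q and KV tensors.
inline OffsetDType pick_ragged_offset_dtype(int64_t max_tokens_q, int64_t max_tokens_kv,
                                            int64_t num_heads, int64_t num_gqa_groups,
                                            int64_t head_dim_qk, int64_t head_dim_v) {
  const int64_t worst_q = max_tokens_q * num_heads * head_dim_qk;
  const int64_t worst_kv =
      max_tokens_kv * num_gqa_groups * std::max(head_dim_qk, head_dim_v);
  return std::max(worst_q, worst_kv) > INT32_MAX ? OffsetDType::kInt64 : OffsetDType::kInt32;
}

Whatever the exact bound, the new checks below only gate backend selection: the FP8 and max-512 paths are skipped whenever 64-bit ragged offsets are required, and the arbitrary-seqlen path accepts them only on cuDNN 9.5 or newer via `supported_ragged_offset_size`.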
+ const bool requires_64bit_ragged_offset = + (qkv_format == NVTE_THD && fused_attn::get_ragged_offset_dtype( + layout_group, num_attn_heads, num_gqa_groups, max_seqlen_q, + max_seqlen_kv, head_dim_qk, head_dim_v) == DType::kInt64); + const bool supported_ragged_offset_size = + (!requires_64bit_ragged_offset || cudnn_runtime_version >= 90500); + if (((q_dtype == NVTEDType::kNVTEFloat8E4M3) || (q_dtype == NVTEDType::kNVTEFloat8E5M2)) && (sm_arch_ >= 90) && (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) && (((cudnn_runtime_version >= 8900) && (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) && @@ -91,7 +103,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) && ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) || - (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK))))) { + (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)))) && + !requires_64bit_ragged_offset) { if (cudnn_runtime_version >= 8900) { backend = NVTE_Fused_Attn_Backend::NVTE_FP8; } else { @@ -118,7 +131,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) || (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) || (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD)) && - ((window_size_left == -1) && (window_size_right == -1 || window_size_right == 0))) { + ((window_size_left == -1) && (window_size_right == -1 || window_size_right == 0)) && + !requires_64bit_ragged_offset) { flag_m512 = true; } if ( // architecture @@ -168,10 +182,10 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) && bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS)) && // qkv format - ((qkv_format == NVTE_QKV_Format::NVTE_SBHD) || - (sm_arch_ >= 90 && cudnn_runtime_version >= 90100 && num_attn_heads == num_gqa_groups && - qkv_format == NVTE_QKV_Format::NVTE_THD) || - (qkv_format == NVTE_QKV_Format::NVTE_BSHD)) && + ((qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD) || + (qkv_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90 && + ((cudnn_runtime_version >= 90100 && num_attn_heads == num_gqa_groups) || + (cudnn_runtime_version >= 90600)))) && // sliding window ((cudnn_runtime_version < 90200 && window_size_left == -1 && (window_size_right == -1 || window_size_right == 0)) || @@ -183,7 +197,9 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( max_seqlen_q == max_seqlen_kv)) && dropout == 0.0 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && (qkv_format == NVTE_QKV_Format::NVTE_BSHD || - qkv_format == NVTE_QKV_Format::NVTE_SBHD)))))) { + qkv_format == NVTE_QKV_Format::NVTE_SBHD))))) && + // check 64-bit ragged offset support + (supported_ragged_offset_size)) { flag_arb = true; } if (((max_seqlen_q > 512) || (max_seqlen_kv > 512)) && (flag_arb == true)) { @@ -257,6 +273,11 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTE_ERROR("nvte_fused_attn_fwd_qkvpacked only supports H3D and 3HD layouts!"); } size_t d = input_QKV->data.shape[ndim - 1]; + size_t t = 0; + NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); + if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + t = input_QKV->data.shape[0]; + } auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); const NVTEDType QKV_type = static_cast(input_QKV->data.dtype); @@ -277,7 +298,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, } else if (fused_attention_backend == 
NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8900) fused_attn_arbitrary_seqlen_fwd_qkvpacked( - b, h, max_seqlen, d, is_training, attn_scale, dropout, qkv_layout, bias_type, + b, h, max_seqlen, d, t, is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, input_QKV, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens, input_cu_seqlens_padded, input_rng_state, wkspace, stream, handle); @@ -334,6 +355,11 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con NVTE_ERROR("nvte_fused_attn_fwd_qkvpacked only supports H3D and 3HD layouts!"); } size_t d = input_QKV->data.shape[ndim - 1]; + size_t t = 0; + NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); + if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + t = input_QKV->data.shape[0]; + } auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); const NVTEDType QKV_type = static_cast(input_QKV->data.dtype); @@ -362,7 +388,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con input_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); } fused_attn_arbitrary_seqlen_bwd_qkvpacked( - b, h, max_seqlen, d, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, + b, h, max_seqlen, d, t, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, deterministic, input_QKV, input_O, input_dO, input_Bias, output_S, output_dQKV, output_dBias, input_cu_seqlens, input_cu_seqlens_padded, input_rng_state, wkspace, stream, handle); @@ -427,6 +453,13 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const } else { NVTE_ERROR("nvte_fused_attn_fwd_kvpacked only supports HD_H2D and HD_2HD layouts!"); } + size_t t_q = 0; + size_t t_kv = 0; + NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); + if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + t_q = input_Q->data.shape[0]; + t_kv = input_KV->data.shape[0]; + } auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); const NVTEDType Q_type = static_cast(input_Q->data.dtype); @@ -448,9 +481,9 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8903) fused_attn_arbitrary_seqlen_fwd_kvpacked( - b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, dropout, qkv_layout, - bias_type, attn_mask_type, window_size_left, window_size_right, input_Q, input_KV, - input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, + b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, is_training, attn_scale, dropout, + qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, input_Q, + input_KV, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle); #else @@ -511,6 +544,13 @@ void nvte_fused_attn_bwd_kvpacked( } else { NVTE_ERROR("nvte_fused_attn_fwd_kvpacked only supports HD_H2D and HD_2HD layouts!"); } + size_t t_q = 0; + size_t t_kv = 0; + NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); + if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + t_q = input_Q->data.shape[0]; + t_kv = input_KV->data.shape[0]; + } auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); const NVTEDType Q_type = static_cast(input_Q->data.dtype); @@ 
-541,9 +581,9 @@ void nvte_fused_attn_bwd_kvpacked( input_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); } fused_attn_arbitrary_seqlen_bwd_kvpacked( - b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, qkv_layout, bias_type, - attn_mask_type, window_size_left, window_size_right, deterministic, input_Q, input_KV, - input_O, input_dO, input_Bias, output_S, output_dQ, output_dKV, output_dBias, + b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, attn_scale, dropout, qkv_layout, + bias_type, attn_mask_type, window_size_left, window_size_right, deterministic, input_Q, + input_KV, input_O, input_dO, input_Bias, output_S, output_dQ, output_dKV, output_dBias, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle); #else @@ -601,6 +641,13 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso size_t h_kv = input_K->data.shape[ndim - 2]; size_t d_qk = input_Q->data.shape[ndim - 1]; size_t d_v = input_V->data.shape[ndim - 1]; + size_t t_q = 0; + size_t t_kv = 0; + NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); + if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + t_q = input_Q->data.shape[0]; + t_kv = input_K->data.shape[0]; + } auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); const NVTEDType Q_type = static_cast(input_Q->data.dtype); @@ -622,9 +669,9 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8900) fused_attn_arbitrary_seqlen_fwd( - b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, is_training, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, input_Q, - input_K, input_V, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, + b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, is_training, attn_scale, + dropout, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, + input_Q, input_K, input_V, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle); #else @@ -681,6 +728,13 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso size_t h_kv = input_K->data.shape[ndim - 2]; size_t d_qk = input_Q->data.shape[ndim - 1]; size_t d_v = input_V->data.shape[ndim - 1]; + size_t t_q = 0; + size_t t_kv = 0; + NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); + if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + t_q = input_Q->data.shape[0]; + t_kv = input_K->data.shape[0]; + } auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); const NVTEDType Q_type = static_cast(input_Q->data.dtype); @@ -711,10 +765,10 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso input_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); } fused_attn_arbitrary_seqlen_bwd( - b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, attn_scale, dropout, qkv_layout, - bias_type, attn_mask_type, window_size_left, window_size_right, deterministic, input_Q, - input_K, input_V, input_O, input_dO, input_Bias, output_S, output_dQ, output_dK, output_dV, - output_dBias, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, + b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, attn_scale, 
dropout, + qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, deterministic, + input_Q, input_K, input_V, input_O, input_dO, input_Bias, output_S, output_dQ, output_dK, + output_dV, output_dBias, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle); #else const char *err_msg = diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 9eff62debf..f242502261 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -13,6 +13,7 @@ #include #include "../common.h" +#include "../cudnn_utils.h" #include "../util/cuda_runtime.h" #include "../util/system.h" #include "fused_attn_f16_arbitrary_seqlen.h" @@ -49,15 +50,16 @@ namespace transformer_engine { namespace fused_attn { void fused_attn_arbitrary_seqlen_fwd_impl( int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, - int64_t bias_b, int64_t bias_h, bool is_training, float scaling_factor, - float dropout_probability, NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, void *devPtrQ, - void *devPtrK, void *devPtrV, void *devPtrBias, void *devPtrSoftmaxStats, void *devPtrO, - void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, - void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, - cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size, - cudaStream_t stream, cudnnHandle_t handle) { + int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t bias_b, int64_t bias_h, + bool is_training, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, + int64_t window_size_right, void *devPtrQ, void *devPtrK, void *devPtrV, void *devPtrBias, + void *devPtrSoftmaxStats, void *devPtrO, void *devPtrDropoutSeed, void *devPtrDropoutOffset, + void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ, + void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, void *workspace, + size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; + bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI); bool is_causal = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) || @@ -72,10 +74,19 @@ void fused_attn_arbitrary_seqlen_fwd_impl( (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)); bool is_dropout = (is_training && dropout_probability != 0.0f); bool is_ragged = (nvte_get_qkv_format(layout) == NVTE_QKV_Format::NVTE_THD); - if (is_ragged) { + const auto cudnn_runtime_version = cudnnGetVersion(); + + // keep original batch size because cu_seqlens are created with [b+1] shape + int64_t actual_b = b; + if (is_ragged && cudnn_runtime_version >= 90600) { NVTE_CHECK(is_padding, "Ragged QKV input requires padding or padding_causal mask!"); + // replace batch size and maximum sequence lengths with maximum token counts + // for query and key/value so the graph is static within each quantization bucket + b = max_b; + s_q = max_t_q; + s_kv = max_t_kv; } - auto cudnn_runtime_version = cudnnGetVersion(); + const DType ragged_offset_type = 
cudnn_runtime_version >= 90500 ? DType::kInt64 : DType::kInt32; try { FADescriptor_v1 descriptor{b, @@ -115,6 +126,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( std::shared_ptr, // offset_k std::shared_ptr, // offset_v std::shared_ptr, // offset_o + std::shared_ptr, // offset_stats std::shared_ptr, // dropout_seed std::shared_ptr>; // dropout_offset @@ -138,30 +150,10 @@ void fused_attn_arbitrary_seqlen_fwd_impl( std::shared_ptr Q, K, V, attn_scale; std::shared_ptr bias, seq_q, seq_kv; - std::shared_ptr offset_q, offset_k, offset_v, offset_o; + std::shared_ptr offset_q, offset_k, offset_v, offset_o, + offset_stats; std::shared_ptr dropout_seed, dropout_offset; - offset_q = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("offset_q") - .set_dim({b + 1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); - offset_k = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("offset_k") - .set_dim({b + 1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); - offset_v = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("offset_v") - .set_dim({b + 1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); - offset_o = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("offset_o") - .set_dim({b + 1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); - std::vector q_stride(4); std::vector k_stride(4); std::vector v_stride(4); @@ -173,6 +165,21 @@ void fused_attn_arbitrary_seqlen_fwd_impl( NVTE_QKV_Matrix::NVTE_V_Matrix); if (is_ragged) { + offset_q = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("offset_q") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); + offset_k = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("offset_k") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); + offset_v = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("offset_v") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); Q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Q") .set_dim({b, h, s_q, d_qk}) @@ -266,6 +273,11 @@ void fused_attn_arbitrary_seqlen_fwd_impl( generateMatrixStrides(b, h, s_q, s_kv, d_v, o_stride.data(), layout, NVTE_QKV_Matrix::NVTE_O_Matrix); if (is_ragged) { + offset_o = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("offset_o") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); O->set_output(true) .set_dim({b, h, s_q, d_v}) .set_stride(o_stride) @@ -274,10 +286,24 @@ void fused_attn_arbitrary_seqlen_fwd_impl( O->set_output(true).set_dim({b, h, s_q, d_v}).set_stride(o_stride); } - Stats->set_output(true) - .set_data_type(fe::DataType_t::FLOAT) - .set_dim({b, h, s_q, 1}) - .set_stride({h * s_q, s_q, 1, 1}); + if (is_ragged && cudnn_runtime_version >= 90600) { + offset_stats = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("offset_stats") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); + Stats->set_output(true) + .set_data_type(fe::DataType_t::FLOAT) + .set_dim({b, h, s_q, 1}) + .set_stride({h * s_q, 1, h, 1}) + .set_ragged_offset(offset_stats); + } else { + Stats->set_output(true) + .set_data_type(fe::DataType_t::FLOAT) + .set_dim({b, h, s_q, 1}) 
+ .set_stride({h * s_q, s_q, 1, 1}); + } std::tuple, // Q std::shared_ptr, // K @@ -289,8 +315,11 @@ void fused_attn_arbitrary_seqlen_fwd_impl( auto bias_tuple = is_bias ? std::make_tuple(bias) : std::make_tuple(nullptr); auto padding_tuple = is_padding ? std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr); - auto offset_tuple = is_ragged ? std::make_tuple(offset_q, offset_k, offset_v, offset_o) - : std::make_tuple(nullptr, nullptr, nullptr, nullptr); + auto offset_qkvo_tuple = is_ragged ? std::make_tuple(offset_q, offset_k, offset_v, offset_o) + : std::make_tuple(nullptr, nullptr, nullptr, nullptr); + auto offset_s_tuple = (is_ragged && cudnn_runtime_version >= 90600) + ? std::make_tuple(offset_stats) + : std::make_tuple(nullptr); auto dropout_tuple = is_dropout ? std::make_tuple(dropout_seed, dropout_offset) : std::make_tuple(nullptr, nullptr); @@ -300,21 +329,34 @@ void fused_attn_arbitrary_seqlen_fwd_impl( NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle)); NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle)); - auto return_tuple = std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, Stats_tuple, - bias_tuple, padding_tuple, offset_tuple, dropout_tuple); + auto return_tuple = + std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, Stats_tuple, bias_tuple, + padding_tuple, offset_qkvo_tuple, offset_s_tuple, dropout_tuple); cache.insert({descriptor, return_tuple}); return return_tuple; }; auto [mha_graph, Q, K, V, attn_scale, O, Stats, bias, seq_q, seq_kv, offset_q, offset_k, - offset_v, offset_o, dropout_seed, dropout_offset] = + offset_v, offset_o, offset_stats, dropout_seed, dropout_offset] = get_graph(sdpa_f16_fprop_cache, descriptor); - auto plan_workspace_size = mha_graph->get_workspace_size(); // Exit to request upper level API to allocate memory if needed - size_t actual_seqlen_workspace_size = 2 * b * sizeof(int32_t); - size_t seqlen_offsets_workspace_size = 4 * (b + 1) * sizeof(int32_t); + // n.b. Care should be taken to align each of the added worksapce tensors to their type. + // We do this by adding padding at the end of each separate allocation. + auto plan_workspace_size = alignTo<16>(mha_graph->get_workspace_size()); + const size_t num_bytes_per_seqlen = alignTo<16>(b * sizeof(int32_t)); + const size_t actual_seqlen_workspace_size = is_padding ? 
2 * num_bytes_per_seqlen : 0; + const size_t num_bytes_per_ragged_offset = + alignTo<16>((b + 1) * typeToSize(ragged_offset_type)); + size_t seqlen_offsets_workspace_size = 0; + if (is_ragged) { + if (cudnn_runtime_version >= 90600) { + seqlen_offsets_workspace_size = 5 * num_bytes_per_ragged_offset; + } else { + seqlen_offsets_workspace_size = 4 * num_bytes_per_ragged_offset; + } + } if (workspace == nullptr) { *workspace_size = plan_workspace_size + actual_seqlen_workspace_size + seqlen_offsets_workspace_size; @@ -339,9 +381,9 @@ void fused_attn_arbitrary_seqlen_fwd_impl( constexpr size_t nthreads_per_block = 128; const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block; void *devActualSeqlenQ = static_cast(workspace) + plan_workspace_size; - void *devActualSeqlenKV = static_cast(devActualSeqlenQ) + b * sizeof(int32_t); + void *devActualSeqlenKV = static_cast(devActualSeqlenQ) + num_bytes_per_seqlen; cu_seqlens_to_actual_seqlens<<>>( - b, static_cast(devPtrCuSeqlensQ), + actual_b, b, static_cast(devPtrCuSeqlensQ), static_cast(devPtrCuSeqlensKV), static_cast(devActualSeqlenQ), static_cast(devActualSeqlenKV)); variant_pack[seq_q] = devActualSeqlenQ; @@ -353,19 +395,25 @@ void fused_attn_arbitrary_seqlen_fwd_impl( const size_t grid = (b + nthreads_per_block) / nthreads_per_block; void *devOffsetsQ = static_cast(workspace) + plan_workspace_size + actual_seqlen_workspace_size; - void *devOffsetsK = static_cast(devOffsetsQ) + (b + 1) * sizeof(int32_t); - void *devOffsetsV = static_cast(devOffsetsK) + (b + 1) * sizeof(int32_t); - void *devOffsetsO = static_cast(devOffsetsV) + (b + 1) * sizeof(int32_t); - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout); + void *devOffsetsK = static_cast(devOffsetsQ) + num_bytes_per_ragged_offset; + void *devOffsetsV = static_cast(devOffsetsK) + num_bytes_per_ragged_offset; + void *devOffsetsO = static_cast(devOffsetsV) + num_bytes_per_ragged_offset; + void *devOffsetsS = nullptr; + if (cudnn_runtime_version >= 90600) { + devOffsetsS = static_cast(devOffsetsO) + num_bytes_per_ragged_offset; + } + const NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout); cu_seqlens_padded_to_offsets<<>>( - layout_group, b, h, hg, d_qk, d_v, static_cast(devPtrSeqOffsetsQ), - static_cast(devPtrSeqOffsetsKV), static_cast(devOffsetsQ), - static_cast(devOffsetsK), static_cast(devOffsetsV), - static_cast(devOffsetsO)); + layout_group, actual_b, b, h, hg, d_qk, d_v, static_cast(devPtrSeqOffsetsQ), + static_cast(devPtrSeqOffsetsKV), ragged_offset_type, devOffsetsQ, devOffsetsK, + devOffsetsV, devOffsetsO, devOffsetsS); variant_pack[offset_q] = devOffsetsQ; variant_pack[offset_k] = devOffsetsK; variant_pack[offset_v] = devOffsetsV; variant_pack[offset_o] = devOffsetsO; + if (cudnn_runtime_version >= 90600) { + variant_pack[offset_stats] = devOffsetsS; + } } if (is_dropout) { @@ -380,16 +428,18 @@ void fused_attn_arbitrary_seqlen_fwd_impl( void fused_attn_arbitrary_seqlen_bwd_impl( int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, - int64_t bias_b, int64_t bias_h, float scaling_factor, float dropout_probability, - NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, void *devPtrQ, - void *devPtrKTranspose, void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats, - void *devPtrBias, void *devPtrdQ, void *devPtrdK, void *devPtrdV, void *devPtrdO, - void *devPtrdBias, void *devPtrDropoutSeed, void 
*devPtrDropoutOffset, void *devPtrCuSeqlensQ, + int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t bias_b, int64_t bias_h, + float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, + int64_t window_size_right, bool deterministic, void *devPtrQ, void *devPtrKTranspose, + void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats, void *devPtrBias, + void *devPtrdQ, void *devPtrdK, void *devPtrdV, void *devPtrdO, void *devPtrdBias, + void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; + bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI); bool is_causal = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) || @@ -404,9 +454,23 @@ void fused_attn_arbitrary_seqlen_bwd_impl( (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)); bool is_dropout = (dropout_probability != 0.0f); bool is_ragged = (nvte_get_qkv_format(layout) == NVTE_QKV_Format::NVTE_THD); - auto cudnn_runtime_version = cudnnGetVersion(); + const auto cudnn_runtime_version = cudnnGetVersion(); const int device_id = cuda::current_device(); const int sm_arch_ = cuda::sm_arch(device_id); + // keep original batch size because cu_seqlens are created with [b+1] shape + int64_t actual_b = b; + if (is_ragged && cudnn_runtime_version >= 90600) { + NVTE_CHECK(is_padding, "Ragged QKV input requires padding or padding_causal mask!"); + // replace batch size and maximum sequence lengths with maximum token counts + // for query and key/value so the graph is static within each quantization bucket + b = max_b; + s_q = max_t_q; + s_kv = max_t_kv; + } + + // We choose between 32-bit and 64-bit offsets depending on need. + // This allows us to support older cuDNN runtimes gracefully. + const DType ragged_offset_type = cudnn_runtime_version >= 90500 ? 
DType::kInt64 : DType::kInt32; try { FADescriptor_v1 descriptor{b, @@ -451,6 +515,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( std::shared_ptr, // offset_k std::shared_ptr, // offset_v std::shared_ptr, // offset_o + std::shared_ptr, // offset_stats std::shared_ptr, // dropout_seed std::shared_ptr>; // dropout_offset @@ -474,29 +539,10 @@ void fused_attn_arbitrary_seqlen_bwd_impl( std::shared_ptr q, k, v, o, dO, stats, attn_scale; std::shared_ptr bias, dBias, seq_q, seq_kv; - std::shared_ptr offset_q, offset_k, offset_v, offset_o; + std::shared_ptr offset_q, offset_k, offset_v, offset_o, + offset_stats; std::shared_ptr dropout_seed, dropout_offset; - offset_q = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("offset_q") - .set_dim({b + 1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); - offset_k = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("offset_k") - .set_dim({b + 1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); - offset_v = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("offset_v") - .set_dim({b + 1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); - offset_o = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("offset_o") - .set_dim({b + 1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); std::vector q_stride(4); std::vector k_stride(4); std::vector v_stride(4); @@ -511,6 +557,26 @@ void fused_attn_arbitrary_seqlen_bwd_impl( NVTE_QKV_Matrix::NVTE_O_Matrix); if (is_ragged) { + offset_q = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("offset_q") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); + offset_k = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("offset_k") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); + offset_v = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("offset_v") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); + offset_o = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("offset_o") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Q") .set_dim({b, h, s_q, d_qk}) @@ -558,11 +624,26 @@ void fused_attn_arbitrary_seqlen_bwd_impl( .set_dim({b, h, s_q, d_v}) .set_stride(o_stride)); } - stats = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("stats") - .set_dim({b, h, s_q, 1}) - .set_stride({h * s_q, s_q, 1, 1}) - .set_data_type(fe::DataType_t::FLOAT)); + if (is_ragged && cudnn_runtime_version >= 90600) { + offset_stats = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("offset_stats") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); + stats = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("stats") + .set_dim({b, h, s_q, 1}) + .set_stride({h * s_q, 1, h, 1}) + .set_data_type(fe::DataType_t::FLOAT) + .set_ragged_offset(offset_stats)); + } else { + stats = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("stats") + .set_dim({b, h, s_q, 1}) + .set_stride({h * s_q, s_q, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT)); + } attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes() 
.set_name("attn_scale") @@ -578,6 +659,10 @@ void fused_attn_arbitrary_seqlen_bwd_impl( .set_causal_mask_bottom_right(is_bottom_right) .set_attn_scale(attn_scale); + if (is_ragged && cudnn_runtime_version >= 90600) { + sdpa_backward_options.set_max_total_seq_len_q(s_q); + } + if (cudnn_runtime_version >= 90200 && window_size_left != -1) { sdpa_backward_options.set_sliding_window_length(window_size_left + 1); } @@ -671,8 +756,11 @@ void fused_attn_arbitrary_seqlen_bwd_impl( auto bias_tuple = is_bias ? std::make_tuple(bias, dBias) : std::make_tuple(nullptr, nullptr); auto padding_tuple = is_padding ? std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr); - auto offset_tuple = is_ragged ? std::make_tuple(offset_q, offset_k, offset_v, offset_o) - : std::make_tuple(nullptr, nullptr, nullptr, nullptr); + auto offset_qkvo_tuple = is_ragged ? std::make_tuple(offset_q, offset_k, offset_v, offset_o) + : std::make_tuple(nullptr, nullptr, nullptr, nullptr); + auto offset_s_tuple = (is_ragged && cudnn_runtime_version >= 90600) + ? std::make_tuple(offset_stats) + : std::make_tuple(nullptr); auto dropout_tuple = is_dropout ? std::make_tuple(dropout_seed, dropout_offset) : std::make_tuple(nullptr, nullptr); @@ -682,22 +770,34 @@ void fused_attn_arbitrary_seqlen_bwd_impl( NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle)); NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle)); - auto return_tuple = std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, bias_tuple, - padding_tuple, offset_tuple, dropout_tuple); + auto return_tuple = + std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, bias_tuple, padding_tuple, + offset_qkvo_tuple, offset_s_tuple, dropout_tuple); cache.insert({descriptor, return_tuple}); return return_tuple; }; auto [mha_graph, q, k, v, o, dO, stats, attn_scale, dQ, dK, dV, bias, dBias, seq_q, seq_kv, - offset_q, offset_k, offset_v, offset_o, dropout_seed, dropout_offset] = + offset_q, offset_k, offset_v, offset_o, offset_stats, dropout_seed, dropout_offset] = get_graph(sdpa_f16_bprop_cache, descriptor); - auto plan_workspace_size = mha_graph->get_workspace_size(); - // Exit to request upper level API to allocate memory if needed - size_t actual_seqlen_workspace_size = 2 * b * sizeof(int32_t); - size_t seqlen_offsets_workspace_size = 4 * (b + 1) * sizeof(int32_t); + // n.b. Care should be taken to align each of the added worksapce tensors to their type. + // We do this by adding padding at the end of each separate allocation. + auto plan_workspace_size = alignTo<16>(mha_graph->get_workspace_size()); + const size_t num_bytes_per_seqlen = alignTo<16>(b * sizeof(int32_t)); + const size_t actual_seqlen_workspace_size = is_padding ? 
2 * num_bytes_per_seqlen : 0; + const size_t num_bytes_per_ragged_offset = + alignTo<16>((b + 1) * typeToSize(ragged_offset_type)); + size_t seqlen_offsets_workspace_size = 0; + if (is_ragged) { + if (cudnn_runtime_version >= 90600) { + seqlen_offsets_workspace_size = 5 * num_bytes_per_ragged_offset; + } else { + seqlen_offsets_workspace_size = 4 * num_bytes_per_ragged_offset; + } + } if (workspace == nullptr) { *workspace_size = plan_workspace_size + actual_seqlen_workspace_size + seqlen_offsets_workspace_size; @@ -735,9 +835,9 @@ void fused_attn_arbitrary_seqlen_bwd_impl( constexpr size_t nthreads_per_block = 128; const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block; void *devActualSeqlenQ = static_cast(workspace) + plan_workspace_size; - void *devActualSeqlenKV = static_cast(devActualSeqlenQ) + b * sizeof(int32_t); + void *devActualSeqlenKV = static_cast(devActualSeqlenQ) + num_bytes_per_seqlen; cu_seqlens_to_actual_seqlens<<>>( - b, static_cast(devPtrCuSeqlensQ), + actual_b, b, static_cast(devPtrCuSeqlensQ), static_cast(devPtrCuSeqlensKV), static_cast(devActualSeqlenQ), static_cast(devActualSeqlenKV)); variant_pack[seq_q] = devActualSeqlenQ; @@ -749,19 +849,25 @@ void fused_attn_arbitrary_seqlen_bwd_impl( const size_t grid = (b + nthreads_per_block) / nthreads_per_block; void *devOffsetsQ = static_cast(workspace) + plan_workspace_size + actual_seqlen_workspace_size; - void *devOffsetsK = static_cast(devOffsetsQ) + (b + 1) * sizeof(int32_t); - void *devOffsetsV = static_cast(devOffsetsK) + (b + 1) * sizeof(int32_t); - void *devOffsetsO = static_cast(devOffsetsV) + (b + 1) * sizeof(int32_t); - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout); + void *devOffsetsK = static_cast(devOffsetsQ) + num_bytes_per_ragged_offset; + void *devOffsetsV = static_cast(devOffsetsK) + num_bytes_per_ragged_offset; + void *devOffsetsO = static_cast(devOffsetsV) + num_bytes_per_ragged_offset; + void *devOffsetsS = nullptr; + if (cudnn_runtime_version >= 90600) { + devOffsetsS = static_cast(devOffsetsO) + num_bytes_per_ragged_offset; + } + const NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout); cu_seqlens_padded_to_offsets<<>>( - layout_group, b, h, hg, d_qk, d_v, static_cast(devPtrSeqOffsetsQ), - static_cast(devPtrSeqOffsetsKV), static_cast(devOffsetsQ), - static_cast(devOffsetsK), static_cast(devOffsetsV), - static_cast(devOffsetsO)); + layout_group, actual_b, b, h, hg, d_qk, d_v, static_cast(devPtrSeqOffsetsQ), + static_cast(devPtrSeqOffsetsKV), ragged_offset_type, devOffsetsQ, devOffsetsK, + devOffsetsV, devOffsetsO, devOffsetsS); variant_pack[offset_q] = devOffsetsQ; variant_pack[offset_k] = devOffsetsK; variant_pack[offset_v] = devOffsetsV; variant_pack[offset_o] = devOffsetsO; + if (cudnn_runtime_version >= 90600) { + variant_pack[offset_stats] = devOffsetsS; + } } if (is_dropout) { @@ -778,10 +884,10 @@ void fused_attn_arbitrary_seqlen_bwd_impl( using namespace transformer_engine::fused_attn; void fused_attn_arbitrary_seqlen_fwd_qkvpacked( - size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, bool is_training, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, - const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O, + size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens, + bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout 
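The updated cu_seqlens_to_actual_seqlens kernel takes both the real and the bucketed batch size and zero-fills the padded tail, so the extra batch slots behave as empty sequences. A hedged host-side reference of that behaviour (hypothetical function name):

#include <cstdint>
#include <vector>

void actual_seqlens_ref(int64_t actual_b, int64_t max_b, const std::vector<int32_t> &cu_q,
                        const std::vector<int32_t> &cu_kv, std::vector<int32_t> *q_len,
                        std::vector<int32_t> *kv_len) {
  // Slots [actual_b, max_b) belong to the bucketed (static-shape) graph only; they get length 0.
  q_len->assign(max_b, 0);
  kv_len->assign(max_b, 0);
  for (int64_t i = 0; i < actual_b; ++i) {
    (*q_len)[i] = cu_q[i + 1] - cu_q[i];
    (*kv_len)[i] = cu_kv[i + 1] - cu_kv[i];
  }
}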
qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, + int64_t window_size_right, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; @@ -789,6 +895,7 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( const auto QKV_type = input_QKV->data.dtype; void *devPtrQKV = input_QKV->data.dptr; NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); + NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); size_t stride = 0; if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { stride = typeToSize(QKV_type) * num_attn_heads * head_dim; @@ -807,17 +914,30 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( bias_b = input_Bias->data.shape[0]; bias_h = input_Bias->data.shape[1]; } + void *devPtrO = output_O->data.dptr; void *devPtrS = nullptr; void *devPtrCuSeqlens = cu_seqlens->data.dptr; void *devPtrSeqOffsets = cu_seqlens_padded->data.dptr; + size_t max_batch_size = 0; + size_t max_tokens = 0; + if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + max_batch_size = get_max_batch_size(batch); + max_tokens = get_max_tokens(num_tokens); + } + if (Aux_CTX_Tensors->size == 0) { + const auto cudnn_runtime_version = cudnnGetVersion(); if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { Aux_CTX_Tensors->size = 3; Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); output_S->data.dptr = nullptr; - output_S->data.shape = {batch, num_attn_heads, max_seqlen, 1}; + if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + output_S->data.shape = {max_tokens, num_attn_heads, 1}; + } else { + output_S->data.shape = {batch, num_attn_heads, max_seqlen, 1}; + } output_S->data.dtype = DType::kFloat32; Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); output_rng_state->data.dptr = nullptr; @@ -831,7 +951,11 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( Aux_CTX_Tensors->size = 2; Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); output_S->data.dptr = nullptr; - output_S->data.shape = {batch, num_attn_heads, max_seqlen, 1}; + if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + output_S->data.shape = {max_tokens, num_attn_heads, 1}; + } else { + output_S->data.shape = {batch, num_attn_heads, max_seqlen, 1}; + } output_S->data.dtype = DType::kFloat32; Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); output_rng_state->data.dptr = nullptr; @@ -861,12 +985,12 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( size_t workspace_size = 0; fused_attn_arbitrary_seqlen_fwd_impl( - batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, head_dim, bias_b, - bias_h, is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, - window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, - devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, devPtrSeqOffsets, - devPtrSeqOffsets, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, - handle); + batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, head_dim, + max_batch_size, max_tokens, max_tokens, bias_b, bias_h, is_training, attn_scale, p_dropout, + qkv_layout, bias_type, mask_type, window_size_left, window_size_right, 
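The max_batch_size/max_tokens values used here come from the bucketing helpers added to utils.cu further down. A hedged reference reimplementation, included only to illustrate the bucket boundaries:

#include <cmath>
#include <cstddef>

size_t max_batch_bucket(size_t b) {
  const size_t log2_b = static_cast<size_t>(std::ceil(std::log2(static_cast<double>(b))));
  if (log2_b <= 5) return 32;                                          // 1..32   -> 32
  if (log2_b <= 9) return static_cast<size_t>(std::pow(2.0, log2_b));  // 33..512 -> next pow2
  return b;                                                            // larger batches unchanged
}

size_t max_token_bucket(size_t t) {
  const size_t log2_t = static_cast<size_t>(std::ceil(std::log2(static_cast<double>(t))));
  if (log2_t <= 10) return 1024;                                        // up to 1K  -> 1024
  if (log2_t <= 15) return static_cast<size_t>(std::pow(2.0, log2_t));  // up to 32K -> next pow2
  return (t + 32767) / 32768 * 32768;                                   // beyond    -> 32K steps
}
// e.g. max_batch_bucket(48) == 64 and max_token_bucket(50000) == 65536; with cuDNN >= 9.6 and
// THD input the softmax-stats aux tensor is then sized {max_token_bucket(num_tokens), heads, 1}.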
devPtrQ, devPtrK, + devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, + devPtrCuSeqlens, devPtrCuSeqlens, devPtrSeqOffsets, devPtrSeqOffsets, + get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { @@ -884,10 +1008,10 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( } void fused_attn_arbitrary_seqlen_bwd_qkvpacked( - size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, - const Tensor *input_QKV, const Tensor *input_O, const Tensor *input_dO, + size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens, + float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, + bool deterministic, const Tensor *input_QKV, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { @@ -895,7 +1019,6 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked( const auto QKV_type = input_QKV->data.dtype; void *devPtrQKV = input_QKV->data.dptr; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); size_t stride = 0; if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { @@ -920,6 +1043,14 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked( bias_h = output_dBias->data.shape[1]; } + size_t max_batch_size = 0; + size_t max_tokens = 0; + NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); + if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + max_batch_size = get_max_batch_size(batch); + max_tokens = get_max_tokens(num_tokens); + } + void *devPtrdQKV = output_dQKV->data.dptr; void *devPtrdQ = devPtrdQKV; void *devPtrdK = static_cast(static_cast(devPtrdQKV) + stride); @@ -938,12 +1069,13 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked( size_t workspace_size = 0; fused_attn_arbitrary_seqlen_bwd_impl( - batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, head_dim, bias_b, - bias_h, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, - window_size_right, deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, - devPtrBias, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrDropoutSeed, - devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, devPtrSeqOffsets, devPtrSeqOffsets, - get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); + batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, head_dim, + max_batch_size, max_tokens, max_tokens, bias_b, bias_h, attn_scale, p_dropout, qkv_layout, + bias_type, mask_type, window_size_left, window_size_right, deterministic, devPtrQ, devPtrK, + devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, + devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, + devPtrSeqOffsets, devPtrSeqOffsets, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, + &workspace_size, stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { @@ -961,19 +1093,21 @@ void 
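For the packed-QKV wrappers above, K and V (and dK/dV) are reached by offsetting the packed base pointer by the layout stride. The sketch below illustrates the 3HD case only, with hypothetical names, and is not the library's code:

#include <cstddef>
#include <cstdint>

struct QkvPtrs {
  void *q, *k, *v;
};

QkvPtrs split_packed_qkv_3hd(void *qkv, size_t elem_bytes, size_t num_heads, size_t head_dim) {
  // Matches the 3HD branch: the three matrices are interleaved with a stride of h * d elements.
  const size_t stride = elem_bytes * num_heads * head_dim;
  auto *base = static_cast<int8_t *>(qkv);
  return {base, base + stride, base + 2 * stride};
}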
fused_attn_arbitrary_seqlen_bwd_qkvpacked( } void fused_attn_arbitrary_seqlen_fwd_kvpacked( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim, bool is_training, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, - const Tensor *input_KV, const Tensor *input_Bias, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, - const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { + size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv, + bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, + int64_t window_size_right, const Tensor *input_Q, const Tensor *input_KV, + const Tensor *input_Bias, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, + const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, + const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, + cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; const auto QKV_type = input_Q->data.dtype; void *devPtrQ = input_Q->data.dptr; void *devPtrKV = input_KV->data.dptr; NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); + NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); size_t stride = 0; if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { stride = typeToSize(QKV_type) * num_gqa_groups * head_dim; @@ -991,6 +1125,7 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( bias_b = input_Bias->data.shape[0]; bias_h = input_Bias->data.shape[1]; } + void *devPtrO = output_O->data.dptr; void *devPtrS = nullptr; @@ -999,12 +1134,26 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( void *devPtrSeqOffsetsQ = cu_seqlens_q_padded->data.dptr; void *devPtrSeqOffsetsKV = cu_seqlens_kv_padded->data.dptr; + size_t max_batch_size = 0; + size_t max_tokens_q = 0; + size_t max_tokens_kv = 0; + if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + max_batch_size = get_max_batch_size(batch); + max_tokens_q = get_max_tokens(num_tokens_q); + max_tokens_kv = get_max_tokens(num_tokens_kv); + } + if (Aux_CTX_Tensors->size == 0) { + const auto cudnn_runtime_version = cudnnGetVersion(); if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { Aux_CTX_Tensors->size = 3; Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); output_S->data.dptr = nullptr; - output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; + if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + output_S->data.shape = {max_tokens_q, num_attn_heads, 1}; + } else { + output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; + } output_S->data.dtype = DType::kFloat32; Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); output_rng_state->data.dptr = nullptr; @@ -1018,7 +1167,11 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( Aux_CTX_Tensors->size = 2; Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); output_S->data.dptr = nullptr; - output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; + if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + 
output_S->data.shape = {max_tokens_q, num_attn_heads, 1}; + } else { + output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; + } output_S->data.dtype = DType::kFloat32; Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); output_rng_state->data.dptr = nullptr; @@ -1049,11 +1202,11 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( fused_attn_arbitrary_seqlen_fwd_impl( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, head_dim, - bias_b, bias_h, is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, - window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, - devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, - devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, - &workspace_size, stream, handle); + max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, is_training, attn_scale, + p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, devPtrQ, + devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, + devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, + get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { @@ -1072,12 +1225,13 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( void fused_attn_arbitrary_seqlen_bwd_kvpacked( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q, - const Tensor *input_KV, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, - Tensor *output_S, Tensor *output_dQ, Tensor *output_dKV, Tensor *output_dBias, - const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, + size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv, + float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, + bool deterministic, const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_O, + const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ, + Tensor *output_dKV, Tensor *output_dBias, const Tensor *cu_seqlens_q, + const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; @@ -1108,6 +1262,16 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked( bias_h = output_dBias->data.shape[1]; } + size_t max_batch_size = 0; + size_t max_tokens_q = 0; + size_t max_tokens_kv = 0; + NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); + if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + max_batch_size = get_max_batch_size(batch); + max_tokens_q = get_max_tokens(num_tokens_q); + max_tokens_kv = get_max_tokens(num_tokens_kv); + } + void *devPtrdQ = output_dQ->data.dptr; void *devPtrdKV = output_dKV->data.dptr; void *devPtrdK = devPtrdKV; @@ -1129,12 +1293,12 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked( fused_attn_arbitrary_seqlen_bwd_impl( batch, num_attn_heads, 
num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, head_dim, - bias_b, bias_h, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, - window_size_right, deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, - devPtrBias, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrDropoutSeed, - devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, - devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, - stream, handle); + max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, attn_scale, p_dropout, + qkv_layout, bias_type, mask_type, window_size_left, window_size_right, deterministic, devPtrQ, + devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrdQ, devPtrdK, devPtrdV, + devPtrdO, devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, + devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), + workspace->data.dptr, &workspace_size, stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { @@ -1153,8 +1317,9 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked( void fused_attn_arbitrary_seqlen_fwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, + size_t num_tokens_kv, bool is_training, float attn_scale, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, @@ -1163,6 +1328,7 @@ void fused_attn_arbitrary_seqlen_fwd( using namespace transformer_engine; const auto QKV_type = input_Q->data.dtype; + NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); void *devPtrQ = input_Q->data.dptr; void *devPtrK = input_K->data.dptr; void *devPtrV = input_V->data.dptr; @@ -1182,12 +1348,26 @@ void fused_attn_arbitrary_seqlen_fwd( void *devPtrSeqOffsetsQ = cu_seqlens_q_padded->data.dptr; void *devPtrSeqOffsetsKV = cu_seqlens_kv_padded->data.dptr; + size_t max_batch_size = 0; + size_t max_tokens_q = 0; + size_t max_tokens_kv = 0; + if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + max_batch_size = get_max_batch_size(batch); + max_tokens_q = get_max_tokens(num_tokens_q); + max_tokens_kv = get_max_tokens(num_tokens_kv); + } + if (Aux_CTX_Tensors->size == 0) { + const auto cudnn_runtime_version = cudnnGetVersion(); if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { Aux_CTX_Tensors->size = 3; Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); output_S->data.dptr = nullptr; - output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; + if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + output_S->data.shape = {max_tokens_q, num_attn_heads, 1}; + } else { + output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; + } output_S->data.dtype = DType::kFloat32; Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); output_rng_state->data.dptr = nullptr; @@ -1201,7 +1381,11 @@ void fused_attn_arbitrary_seqlen_fwd( 
Aux_CTX_Tensors->size = 2; Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); output_S->data.dptr = nullptr; - output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; + if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + output_S->data.shape = {max_tokens_q, num_attn_heads, 1}; + } else { + output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; + } output_S->data.dtype = DType::kFloat32; Tensor *output_rng_state = reinterpret_cast(Aux_CTX_Tensors->tensors[1]); output_rng_state->data.dptr = nullptr; @@ -1232,11 +1416,11 @@ void fused_attn_arbitrary_seqlen_fwd( fused_attn_arbitrary_seqlen_fwd_impl( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, - bias_b, bias_h, is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, - window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, - devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, - devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, - &workspace_size, stream, handle); + max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, is_training, attn_scale, + p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, devPtrQ, + devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, + devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, + get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { @@ -1255,13 +1439,13 @@ void fused_attn_arbitrary_seqlen_fwd( void fused_attn_arbitrary_seqlen_bwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q, - const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, - const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ, Tensor *output_dK, - Tensor *output_dV, Tensor *output_dBias, const Tensor *cu_seqlens_q, - const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, + size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, + int64_t window_size_right, bool deterministic, const Tensor *input_Q, const Tensor *input_K, + const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, + Tensor *output_S, Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias, + const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; @@ -1282,6 +1466,16 @@ void fused_attn_arbitrary_seqlen_bwd( bias_h = output_dBias->data.shape[1]; } + size_t max_batch_size = 0; + size_t max_tokens_q = 0; + size_t max_tokens_kv = 0; + NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); + if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + max_batch_size = 
get_max_batch_size(batch); + max_tokens_q = get_max_tokens(num_tokens_q); + max_tokens_kv = get_max_tokens(num_tokens_kv); + } + void *devPtrdQ = output_dQ->data.dptr; void *devPtrdK = output_dK->data.dptr; void *devPtrdV = output_dV->data.dptr; @@ -1301,12 +1495,12 @@ void fused_attn_arbitrary_seqlen_bwd( fused_attn_arbitrary_seqlen_bwd_impl( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, - bias_b, bias_h, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, - window_size_right, deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, - devPtrBias, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrDropoutSeed, - devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, - devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, - stream, handle); + max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, attn_scale, p_dropout, + qkv_layout, bias_type, mask_type, window_size_left, window_size_right, deterministic, devPtrQ, + devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrdQ, devPtrdK, devPtrdV, + devPtrdO, devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, + devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), + workspace->data.dptr, &workspace_size, stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h index 4b523cca1a..3a1216f891 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h @@ -19,47 +19,50 @@ namespace transformer_engine { #if (CUDNN_VERSION >= 8900) void fused_attn_arbitrary_seqlen_fwd_qkvpacked( - size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, bool is_training, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, - const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O, + size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens, + bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, + int64_t window_size_right, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); void fused_attn_arbitrary_seqlen_bwd_qkvpacked( - size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, - const Tensor *input_QKV, const Tensor *input_O, const Tensor *input_dO, + size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens, + float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, + bool deterministic, const Tensor *input_QKV, const Tensor *input_O, const Tensor 
*input_dO, const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); void fused_attn_arbitrary_seqlen_fwd_kvpacked( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim, bool is_training, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, - const Tensor *input_KV, const Tensor *input_Bias, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, - const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); + size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv, + bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, + int64_t window_size_right, const Tensor *input_Q, const Tensor *input_KV, + const Tensor *input_Bias, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, + const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, + const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, + cudaStream_t stream, cudnnHandle_t handle); void fused_attn_arbitrary_seqlen_bwd_kvpacked( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q, - const Tensor *input_KV, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, - Tensor *output_S, Tensor *output_dQ, Tensor *output_dKV, Tensor *output_dBias, - const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, + size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv, + float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, + bool deterministic, const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_O, + const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ, + Tensor *output_dKV, Tensor *output_dBias, const Tensor *cu_seqlens_q, + const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); void fused_attn_arbitrary_seqlen_fwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, + size_t num_tokens_kv, bool is_training, float attn_scale, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, const 
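These prototypes show that every arbitrary-seqlen entry point now takes total token counts (num_tokens, or num_tokens_q/num_tokens_kv) immediately after the head-dim arguments. As an illustration of where such a value typically comes from on the caller side (an assumption, not shown in this patch), the token count of a THD-packed tensor equals the last entry of its padded cumulative-seqlen vector:

#include <cstddef>
#include <cstdint>
#include <vector>

size_t tokens_from_cu_seqlens_padded(const std::vector<int32_t> &cu_seqlens_padded) {
  // cu_seqlens_padded has batch + 1 entries; the last one is the total (padded) token count.
  return cu_seqlens_padded.empty() ? 0 : static_cast<size_t>(cu_seqlens_padded.back());
}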
Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, @@ -68,13 +71,13 @@ void fused_attn_arbitrary_seqlen_fwd( void fused_attn_arbitrary_seqlen_bwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q, - const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, - const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ, Tensor *output_dK, - Tensor *output_dV, Tensor *output_dBias, const Tensor *cu_seqlens_q, - const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, + size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, + int64_t window_size_right, bool deterministic, const Tensor *input_Q, const Tensor *input_K, + const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, + Tensor *output_S, Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias, + const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu index 88c1490c01..9341ebf5f9 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu @@ -12,6 +12,7 @@ #include #include "../common.h" +#include "../cudnn_utils.h" #include "fused_attn_f16_max512_seqlen.h" #include "utils.h" @@ -746,7 +747,7 @@ void fused_attn_max_512_fwd_impl( void *devActualSeqlenQ = static_cast(workspace) + plan_workspace_size; void *devActualSeqlenK = static_cast(devActualSeqlenQ) + b * sizeof(int32_t); cu_seqlens_to_actual_seqlens<<>>( - b, static_cast(devPtrCuSeqlenQ), + b, b, static_cast(devPtrCuSeqlenQ), static_cast(devPtrCuSeqlenKV), static_cast(devActualSeqlenQ), static_cast(devActualSeqlenK)); NVTE_CHECK_CUDA(cudaGetLastError()); @@ -1169,7 +1170,7 @@ void fused_attn_max_512_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv void *devActualSeqlenQ = static_cast(workspace) + plan_workspace_size; void *devActualSeqlenK = static_cast(devActualSeqlenQ) + b * sizeof(int32_t); cu_seqlens_to_actual_seqlens<<>>( - b, static_cast(devPtrCuSeqlenQ), + b, b, static_cast(devPtrCuSeqlenQ), static_cast(devPtrCuSeqlenKV), static_cast(devActualSeqlenQ), static_cast(devActualSeqlenK)); NVTE_CHECK_CUDA(cudaGetLastError()); diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index fb7765e1a4..f8fe458219 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -5,6 +5,7 @@ ************************************************************************/ #include "../common.h" +#include 
"../cudnn_utils.h" #include "../util/system.h" #include "fused_attn_fp8.h" #include "utils.h" diff --git a/transformer_engine/common/fused_attn/utils.cu b/transformer_engine/common/fused_attn/utils.cu index 56dbb278b4..a053c55fb6 100644 --- a/transformer_engine/common/fused_attn/utils.cu +++ b/transformer_engine/common/fused_attn/utils.cu @@ -4,7 +4,11 @@ * See LICENSE for license information. ************************************************************************/ +#include +#include + #include "../common.h" +#include "../cudnn_utils.h" #include "transformer_engine/fused_attn.h" #include "utils.h" @@ -337,7 +341,7 @@ cudnn_frontend::Operation ternary_pw_op_create(cudnn_frontend::Tensor const &xDe } // convert cu_seqlens_q to qkv/o_ragged_offset and actual_seqlens_q -__global__ void cu_seqlens_to_offsets(size_t b, size_t h, size_t d, int32_t *cu_seqlens_q, +__global__ void cu_seqlens_to_offsets(int64_t b, int64_t h, int64_t d, int32_t *cu_seqlens_q, int32_t *actual_seqlens_q, int32_t *qkv_ragged_offset, int32_t *o_ragged_offset) { size_t tid = blockIdx.x * blockDim.x + threadIdx.x; @@ -351,93 +355,145 @@ __global__ void cu_seqlens_to_offsets(size_t b, size_t h, size_t d, int32_t *cu_ } // convert cu_seqlens to actual_seqlens -__global__ void cu_seqlens_to_actual_seqlens(size_t b, int32_t const *const q_cu_seqlens, +__global__ void cu_seqlens_to_actual_seqlens(int64_t actual_b, int64_t max_b, + int32_t const *const q_cu_seqlens, int32_t const *const kv_cu_seqlens, int32_t *q_seqlens, int32_t *kv_seqlens) { size_t tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid < b) { + if (tid < actual_b) { q_seqlens[tid] = q_cu_seqlens[tid + 1] - q_cu_seqlens[tid]; kv_seqlens[tid] = kv_cu_seqlens[tid + 1] - kv_cu_seqlens[tid]; + } else if (tid < max_b) { + q_seqlens[tid] = 0; + kv_seqlens[tid] = 0; } } // convert cu_seqlens_padded to offsets -__global__ void cu_seqlens_padded_to_offsets(NVTE_QKV_Layout_Group layout_group, size_t b, size_t h, - size_t hg, size_t d_qk, size_t d_v, - int32_t *cu_seqlens_q_padded, - int32_t *cu_seqlens_kv_padded, int32_t *offsets_q, - int32_t *offsets_k, int32_t *offsets_v, - int32_t *offsets_o) { +template +__device__ void cu_seqlens_padded_to_offsets_impl( + NVTE_QKV_Layout_Group layout_group, int64_t actual_b, int64_t max_b, int64_t h, int64_t hg, + int64_t d_qk, int64_t d_v, const int32_t *cu_seqlens_q_padded, + const int32_t *cu_seqlens_kv_padded, OFFSETS_T *offsets_q, OFFSETS_T *offsets_k, + OFFSETS_T *offsets_v, OFFSETS_T *offsets_o, OFFSETS_T *offsets_s) { size_t tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid < b + 1) { - offsets_o[tid] = h * d_v * cu_seqlens_q_padded[tid]; + auto cu_seqlens_id = min(tid, actual_b); + if (tid <= max_b) { + offsets_o[tid] = h * d_v * cu_seqlens_q_padded[cu_seqlens_id]; + if (offsets_s != nullptr) { + offsets_s[tid] = h * cu_seqlens_q_padded[cu_seqlens_id]; + } switch (layout_group) { case NVTE_QKV_Layout_Group::NVTE_HD_HD_HD: - offsets_q[tid] = h * d_qk * cu_seqlens_q_padded[tid]; - offsets_k[tid] = hg * d_qk * cu_seqlens_kv_padded[tid]; - offsets_v[tid] = hg * d_v * cu_seqlens_kv_padded[tid]; + offsets_q[tid] = h * d_qk * cu_seqlens_q_padded[cu_seqlens_id]; + offsets_k[tid] = hg * d_qk * cu_seqlens_kv_padded[cu_seqlens_id]; + offsets_v[tid] = hg * d_v * cu_seqlens_kv_padded[cu_seqlens_id]; break; case NVTE_QKV_Layout_Group::NVTE_3HD: case NVTE_QKV_Layout_Group::NVTE_H3D: - offsets_q[tid] = 3 * h * d_qk * cu_seqlens_q_padded[tid]; - offsets_k[tid] = offsets_q[tid]; - offsets_v[tid] = offsets_q[tid]; + offsets_q[tid] = 3 
* h * d_qk * cu_seqlens_q_padded[cu_seqlens_id];
+ offsets_k[tid] = offsets_q[tid];
+ offsets_v[tid] = offsets_q[tid];
 break;
 case NVTE_QKV_Layout_Group::NVTE_HD_2HD:
 case NVTE_QKV_Layout_Group::NVTE_HD_H2D:
- offsets_q[tid] = h * d_qk * cu_seqlens_q_padded[tid];
- offsets_k[tid] = 2 * hg * d_qk * cu_seqlens_kv_padded[tid];
- offsets_v[tid] = offsets_k[tid];
+ offsets_q[tid] = h * d_qk * cu_seqlens_q_padded[cu_seqlens_id];
+ offsets_k[tid] = 2 * hg * d_qk * cu_seqlens_kv_padded[cu_seqlens_id];
+ offsets_v[tid] = offsets_k[tid];
 break;
 }
 }
 }
-} // namespace fused_attn
+__global__ void cu_seqlens_padded_to_offsets(NVTE_QKV_Layout_Group layout_group, int64_t actual_b,
+ int64_t max_b, int64_t h, int64_t hg, int64_t d_qk,
+ int64_t d_v, const int32_t *cu_seqlens_q_padded,
+ const int32_t *cu_seqlens_kv_padded,
+ DType offset_dtype, void *offsets_q, void *offsets_k,
+ void *offsets_v, void *offsets_o, void *offsets_s) {
+ if (offset_dtype == DType::kInt32) {
+ cu_seqlens_padded_to_offsets_impl<int32_t>(
+ layout_group, actual_b, max_b, h, hg, d_qk, d_v, cu_seqlens_q_padded, cu_seqlens_kv_padded,
+ reinterpret_cast<int32_t *>(offsets_q), reinterpret_cast<int32_t *>(offsets_k),
+ reinterpret_cast<int32_t *>(offsets_v), reinterpret_cast<int32_t *>(offsets_o),
+ reinterpret_cast<int32_t *>(offsets_s));
+ } else {
+ assert(offset_dtype == DType::kInt64 && "expect int64");
+ cu_seqlens_padded_to_offsets_impl<int64_t>(
+ layout_group, actual_b, max_b, h, hg, d_qk, d_v, cu_seqlens_q_padded, cu_seqlens_kv_padded,
+ reinterpret_cast<int64_t *>(offsets_q), reinterpret_cast<int64_t *>(offsets_k),
+ reinterpret_cast<int64_t *>(offsets_v), reinterpret_cast<int64_t *>(offsets_o),
+ reinterpret_cast<int64_t *>(offsets_s));
+ }
+}

-// get cuDNN data type
-cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t) {
- using namespace transformer_engine;
- switch (t) {
- case DType::kInt32:
- return CUDNN_DATA_INT32;
- case DType::kInt64:
- return CUDNN_DATA_INT64;
- case DType::kFloat16:
- return CUDNN_DATA_HALF;
- case DType::kFloat32:
- return CUDNN_DATA_FLOAT;
- case DType::kBFloat16:
- return CUDNN_DATA_BFLOAT16;
- case DType::kFloat8E4M3:
- return CUDNN_DATA_FP8_E4M3;
- case DType::kFloat8E5M2:
- return CUDNN_DATA_FP8_E5M2;
- default:
- NVTE_ERROR("Invalid cuDNN data type. \n");
- }
-}
+
\n"); +DType get_ragged_offset_dtype(NVTE_QKV_Layout_Group layout_group, int64_t num_attn_heads, + int64_t num_gqa_groups, int64_t max_seqlen_q, int64_t max_seqlen_kv, + int64_t head_dim_qk, int64_t head_dim_v) { + std::array offsets_qkvo{}; + switch (layout_group) { + case NVTE_QKV_Layout_Group::NVTE_HD_HD_HD: + offsets_qkvo[0] = num_attn_heads * head_dim_qk * max_seqlen_q; + offsets_qkvo[1] = num_gqa_groups * head_dim_qk * max_seqlen_kv; + offsets_qkvo[2] = num_gqa_groups * head_dim_v * max_seqlen_kv; + break; + case NVTE_QKV_Layout_Group::NVTE_3HD: + case NVTE_QKV_Layout_Group::NVTE_H3D: + offsets_qkvo[0] = 3 * num_attn_heads * head_dim_qk * max_seqlen_q; + offsets_qkvo[1] = offsets_qkvo[0]; + offsets_qkvo[2] = offsets_qkvo[0]; + break; + case NVTE_QKV_Layout_Group::NVTE_HD_2HD: + case NVTE_QKV_Layout_Group::NVTE_HD_H2D: + offsets_qkvo[0] = num_attn_heads * head_dim_qk * max_seqlen_q; + offsets_qkvo[1] = 2 * num_gqa_groups * head_dim_qk * max_seqlen_kv; + offsets_qkvo[2] = offsets_qkvo[1]; + break; } + + offsets_qkvo[3] = num_attn_heads * head_dim_qk * max_seqlen_q; + + size_t max_offset = *std::max_element(offsets_qkvo.begin(), offsets_qkvo.end()); + if (max_offset > std::numeric_limits::max()) { + return DType::kInt64; + } + + return DType::kInt32; } -// get cuDNN data type -cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t) { - using namespace transformer_engine; - switch (t) { - case DType::kInt32: - return cudnn_frontend::DataType_t::INT32; - case DType::kInt64: - return cudnn_frontend::DataType_t::INT64; - case DType::kFloat16: - return cudnn_frontend::DataType_t::HALF; - case DType::kFloat32: - return cudnn_frontend::DataType_t::FLOAT; - case DType::kBFloat16: - return cudnn_frontend::DataType_t::BFLOAT16; - case DType::kFloat8E4M3: - return cudnn_frontend::DataType_t::FP8_E4M3; - case DType::kFloat8E5M2: - return cudnn_frontend::DataType_t::FP8_E5M2; - default: - NVTE_ERROR("Invalid cuDNN data type. \n"); +// quantize batch size +size_t get_max_batch_size(size_t batch_size) { + size_t max_b = batch_size; + size_t log2_b = ceil(log2(batch_size)); + // batch size is expected to be 10s-100s + // b = 1, ..., 32 -> max_b = 32 + // b = 33, ..., 512 -> max_b = next power of 2 + // otherwise -> max_b = b + if (log2_b <= 5) { + max_b = 32; + } else if (log2_b <= 9) { + max_b = pow(2, log2_b); } + return max_b; } + +// quantize token count +size_t get_max_tokens(size_t num_tokens) { + // token count is expected to be 1k's-100k's + // t = 0, ..., 1024 -> max_t = 1024 + // t = 1025, ..., 32k -> max_t = next power of 2 + // t = 32k+1, ... 
-> max_t = increment by 32k + size_t log2_t = ceil(log2(num_tokens)); + size_t max_t = 0; + if (log2_t <= 10) { + max_t = 1024; + } else if (log2_t <= 15) { + max_t = pow(2, log2_t); + } else { + max_t = (num_tokens + 32767) / 32768 * 32768; + } + return max_t; +} + +} // namespace fused_attn } // namespace transformer_engine diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h index d5cf450181..f790d3b567 100644 --- a/transformer_engine/common/fused_attn/utils.h +++ b/transformer_engine/common/fused_attn/utils.h @@ -118,43 +118,30 @@ struct FADescriptor_v1 { } }; -__global__ void cu_seqlens_to_offsets(size_t b, size_t h, size_t d, int32_t *cu_seqlens_q, +__global__ void cu_seqlens_to_offsets(int64_t b, int64_t h, int64_t d, int32_t *cu_seqlens_q, int32_t *actual_seqlens_q, int32_t *qkv_ragged_offset, int32_t *o_ragged_offset); -__global__ void cu_seqlens_to_actual_seqlens(size_t b, int32_t const *const q_cu_seqlens, +__global__ void cu_seqlens_to_actual_seqlens(int64_t actual_b, int64_t max_b, + int32_t const *const q_cu_seqlens, int32_t const *const kv_cu_seqlens, int32_t *q_seqlens, int32_t *kv_seqlens); -__global__ void cu_seqlens_padded_to_offsets(NVTE_QKV_Layout_Group layout_group, size_t b, size_t h, - size_t hg, size_t d_qk, size_t d_v, - int32_t *cu_seqlens_q_padded, - int32_t *cu_seqlens_kv_padded, int32_t *offsets_q, - int32_t *offsets_k, int32_t *offsets_v, - int32_t *offsets_o); -} // namespace fused_attn - -cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t); -cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t); +__global__ void cu_seqlens_padded_to_offsets(NVTE_QKV_Layout_Group layout_group, int64_t actual_b, + int64_t max_b, int64_t h, int64_t hg, int64_t d_qk, + int64_t d_v, const int32_t *cu_seqlens_q_padded, + const int32_t *cu_seqlens_kv_padded, + DType offset_dtype, void *offsets_q, void *offsets_k, + void *offsets_v, void *offsets_o, void *offsets_s); -class cudnnExecutionPlanManager { - public: - static cudnnExecutionPlanManager &Instance() { - static thread_local cudnnExecutionPlanManager instance; - return instance; - } +DType get_ragged_offset_dtype(NVTE_QKV_Layout_Group layout_group, int64_t num_attn_heads, + int64_t num_gqa_groups, int64_t max_seqlen_q, int64_t max_seqlen_kv, + int64_t head_dim_qk, int64_t head_dim_v); - cudnnHandle_t GetCudnnHandle() { - static thread_local std::once_flag flag; - std::call_once(flag, [&] { cudnnCreate(&handle_); }); - return handle_; - } +size_t get_max_batch_size(size_t batch_size); +size_t get_max_tokens(size_t num_tokens); - ~cudnnExecutionPlanManager() {} - - private: - cudnnHandle_t handle_ = nullptr; -}; +} // namespace fused_attn } // namespace transformer_engine #endif diff --git a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h new file mode 100644 index 0000000000..17ecca5ff0 --- /dev/null +++ b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h @@ -0,0 +1,201 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/
+
+#ifndef TRANSFORMER_ENGINE_COMMON_COMM_GEMM_OVERLAP_H_
+#define TRANSFORMER_ENGINE_COMMON_COMM_GEMM_OVERLAP_H_
+
+#include
+#include
+#include
+
+#include
+
+#include "common/comm_gemm_overlap/userbuffers/userbuffers.h"
+
+#define NVTE_COMM_OVERLAP_MAX_STREAMS 3
+
+namespace transformer_engine {
+
+/* \brief Check if Userbuffers bootstraps with direct calls to MPI collectives.
+ * This can be turned on by building Transformer Engine with the `NVTE_UB_WITH_MPI=1` option.
+ *
+ * \return True if Userbuffers is built with MPI
+ */
+bool ubuf_built_with_mpi();
+
+enum class CommOverlapType { RS = 0, AG = 1 };
+
+enum class CommOverlapAlgo {
+ BULK_OVERLAP_AG = 0,
+ BULK_OVERLAP_RS = 1,
+ SPLIT_PIPELINED_AG_P2P = 2,
+ SPLIT_PIPELINED_RS = 3,
+ SPLIT_PIPELINED_RS_P2P = 4,
+ ATOMIC_GEMM_RS = 5,
+ ATOMIC_GEMM_AG_P2P = 6,
+ ATOMIC_GEMM_RS_P2P = 7
+};
+
+class CommOverlapCore {
+ protected:
+ static inline communicator *_ub_comm{nullptr};
+ static inline bool _comm_created{false};
+
+ int _rank;
+ int _tp_id;
+ int _tp_size;
+ int _num_splits;
+ int _math_sms;
+ int _num_comm_sm;
+ int _cga_size;
+ int _use_ce;
+ int _ub_reg;
+ bool _atomic_gemm{false};
+ bool _is_p2p{false};
+
+ TensorWrapper _ubuf;
+ TensorWrapper _counter;
+ float *_ubuf_scale_inv;
+ bool _ubuf_scale_inv_initialized{false};
+
+ std::vector<cudaStream_t> _stream_compute;
+ cudaEvent_t _start_compute, _stop_compute, _start_comm, _stop_comm;
+
+ public:
+ CommOverlapCore(int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes,
+ int tp_size, ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle,
+ int num_splits, int num_max_streams, int comm_cga_size, int num_comm_sm,
+ bool set_sm_margin, bool use_ce, bool atomic_gemm);
+
+ virtual ~CommOverlapCore();
+
+ void set_ubuf_scale_inv(float *scale_inv) {
+ _ubuf_scale_inv = scale_inv;
+ _ubuf_scale_inv_initialized = true;
+ }
+
+ bool is_atomic_gemm() { return _atomic_gemm; }
+
+ bool is_p2p_overlap() { return _is_p2p; }
+
+ bool is_fp8_ubuf() { return _ubuf.element_size() == 1; }
+}; // CommOverlapCore
+
+class CommOverlapBase : public CommOverlapCore {
+ protected:
+ int _rs_kernel_type;
+ cudaStream_t _stream_comm;
+ cudaEvent_t _start_d2dcopy;
+
+ public:
+ CommOverlapBase(const std::vector<size_t> &buffer_shape, DType buffer_dtype, int myrank,
+ int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size,
+ ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, int num_splits = 3,
+ int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2,
+ int num_comm_sm = 16, bool set_sm_margin = true, bool atomic_gemm = false);
+
+ virtual ~CommOverlapBase();
+
+ /*
+ ** Bulk GEMM + COMM
+ ** This function assumes the communication input is pre-copied to _ubuf
+ */
+ void bulk_overlap(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, TensorWrapper &D,
+ TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace,
+ bool grad, bool accumulate, bool use_split_accumulator,
+ CommOverlapType comm_type, TensorWrapper &rs_output, cudaStream_t stream_main);
+
+ /*
+ ** Split FPROP GEMM + ReduceScatter
+ */
+ void atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb,
+ TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out,
+ TensorWrapper &workspace, bool grad, bool accumulate,
+ bool use_split_accumulator, bool gemm_overlap,
+ TensorWrapper &rs_output, cudaStream_t stream_main);
+
+ /*
+ ** Split FPROP GEMM
+ ReduceScatter + */ + void split_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, bool gemm_overlap, TensorWrapper &rs_output, + cudaStream_t stream_main); +}; // CommOverlapBase + +class CommOverlapP2PBase : public CommOverlapCore { + protected: + bool _is_reduce_scatter{false}; + bool _use_multiatomic_ag{false}; + + int _next_rank; + int _prev_rank; + int _rank_round_tp; + int _aggregate; + int _num_ubuf_chunks; + int _self_chunk_id; + + std::vector _ubufs; + + cudaStream_t _stream_send; + cudaStream_t _stream_recv; + cudaEvent_t _stop_send, _stop_recv; + + public: + CommOverlapP2PBase(const std::vector &buffer_shape, DType buffer_dtype, int myrank, + int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size, + ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, + CommOverlapType comm_type, int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, + int comm_cga_size = 1, int num_comm_sm = 1, bool set_sm_margin = false, + bool use_ce = true, bool atomic_gemm = false, bool aggregate = false); + + virtual ~CommOverlapP2PBase(); + + /* + ** Split AllGather + AtomicGEMM using P2P communication + ** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is needed to have AG + ** outputs in each rank to be in the contiguous memory space after all ring exchange phases. + */ + void atomic_gemm_overlap_ag(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, TensorWrapper &B_copy, + cudaStream_t stream_main); + + /* + ** Split AllGather + GEMM using P2P communication + ** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is needed to have AG + ** outputs in each rank to be in the contiguous memory space after all ring exchange phases. 
+ */ + void split_overlap_ag(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, TensorWrapper &B_copy, + cudaStream_t stream_main); + + /* + ** Split ReduceScatter + GEMM using P2P communication + */ + void atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, TensorWrapper &rs_output, + cudaStream_t stream_main); + + /* + ** Split ReduceScatter + GEMM using P2P communication + */ + void split_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, TensorWrapper &rs_output, + cudaStream_t stream_main); +}; // CommOverlapP2PBase + +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_COMM_GEMM_OVERLAP_H_ diff --git a/transformer_engine/common/include/transformer_engine/cudnn.h b/transformer_engine/common/include/transformer_engine/cudnn.h new file mode 100644 index 0000000000..c5e4bc23a9 --- /dev/null +++ b/transformer_engine/common/include/transformer_engine/cudnn.h @@ -0,0 +1,29 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file cudnn.h + * \brief Helper for cuDNN initialization + */ + +#ifndef TRANSFORMER_ENGINE_CUDNN_H_ +#define TRANSFORMER_ENGINE_CUDNN_H_ + +#include "transformer_engine.h" + +/*! \namespace transformer_engine + */ +namespace transformer_engine { + +/*! \brief TE/JAX cudaGraph requires the cuDNN initialization to happen outside of the capturing + * region. This function is a helper to call cudnnCreate() which allocate memory for the handle. + * The function will be called in the initialize() phase of the related XLA custom calls. + */ + +void nvte_cudnn_handle_init(); + +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_CUDNN_H_ diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index 534f6b2181..d302518235 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -24,7 +24,7 @@ extern "C" { enum NVTEDType { kNVTEByte = 0, /*!< Byte */ kNVTEInt32 = 1, /*!< 32-bit integer */ - kNVTEInt64 = 2, /*!< 32-bit integer */ + kNVTEInt64 = 2, /*!< 64-bit integer */ kNVTEFloat32 = 3, /*!< 32-bit float */ kNVTEFloat16 = 4, /*!< 16-bit float (E5M10) */ kNVTEBFloat16 = 5, /*!< 16-bit bfloat (E8M7) */ @@ -78,13 +78,13 @@ NVTETensor nvte_create_tensor(void *dptr, const NVTEShape shape, const NVTEDType */ void nvte_destroy_tensor(NVTETensor tensor); -/*! \brief Get a tensor's data type. +/*! \brief Get a raw pointer to the tensor's data. * * \param[in] tensor Tensor. * - * \return A data type of the input tensor. + * \return A raw pointer to tensor's data. */ -NVTEDType nvte_tensor_type(const NVTETensor tensor); +void *nvte_tensor_data(const NVTETensor tensor); /*! 
\brief Get a tensor's data shape. * @@ -94,13 +94,46 @@ NVTEDType nvte_tensor_type(const NVTETensor tensor); */ NVTEShape nvte_tensor_shape(const NVTETensor tensor); -/*! \brief Get a raw pointer to the tensor's data. +/*! \brief Get a tensor's number of dimensions. * * \param[in] tensor Tensor. * - * \return A raw pointer to tensor's data. + * \return Number of tensor dimensions. */ -void *nvte_tensor_data(const NVTETensor tensor); +size_t nvte_tensor_ndims(const NVTETensor tensor); + +/*! \brief Get the size of a specific tensor dimension. + * + * \param[in] tensor Tensor. + * \param[in] size_t Dimension index. + * + * \return Size of the tensor at the specified dimension. + */ +size_t nvte_tensor_size(const NVTETensor tensor, const size_t dim); + +/*! \brief Get a tensor's total number of elements. + * + * \param[in] tensor Tensor. + * + * \return Number of elements in the tensor. + */ +size_t nvte_tensor_numel(const NVTETensor tensor); + +/*! \brief Get the byte size for the tensor's data type. + * + * \param[in] tensor Tensor. + * + * \return Byte size of the tensor's data type. + */ +size_t nvte_tensor_element_size(const NVTETensor tensor); + +/*! \brief Get a tensor's data type. + * + * \param[in] tensor Tensor. + * + * \return A data type of the input tensor. + */ +NVTEDType nvte_tensor_type(const NVTETensor tensor); /*! \brief Get a pointer to the tensor's amax data. * @@ -265,6 +298,56 @@ class TensorWrapper { return nvte_tensor_shape(tensor_); } + /*! \brief Get the size of this TensorWrapper in the given dimension. + * + * \param[in] size_t Dimension index. + * + * \return Size of this TensorWrapper in given dimension. + */ + size_t size(const size_t dim) const { + if (tensor_ == nullptr) return 0; + return nvte_tensor_size(tensor_, dim); + } + + /*! \brief Get the number of dimensions for this TensorWrapper. + * + * \return Number of dimensions for this TensorWrapper. + */ + size_t ndim() const noexcept { + if (tensor_ == nullptr) return 0; + return nvte_tensor_ndims(tensor_); + } + + /*! \brief Get the number of allocated elements in the tensor. This will return 0 for tensors + * with nullptr data even if the TensorWrapper has a non-zero shape. + * + * + * \return Number of elements in the tensor. + */ + size_t numel() const noexcept { + if (tensor_ == nullptr || this->dptr() == nullptr) return 0; + return nvte_tensor_numel(tensor_); + } + + /*! \brief Get the tensor's element size in bytes. + * + * \return Element size in bytes. + */ + size_t element_size() const noexcept { + if (tensor_ == nullptr) return 0; + return nvte_tensor_element_size(tensor_); + } + + /*! \brief Get the tensor's allocated size in bytes. This will return 0 for tensors with nullptr + * data even if the TensorWrapper has a non-zero shape and valid dtype. + * + * \return Total tensor size in bytes. + */ + size_t bytes() const noexcept { + if (tensor_ == nullptr || this->dptr() == nullptr) return 0; + return nvte_tensor_numel(tensor_) * nvte_tensor_element_size(tensor_); + } + /*! \brief Get the data type of this TensorWrapper. * * \return Data type of this TensorWrapper. 
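The hunk above rounds out the C-level introspection API: nvte_tensor_ndims, nvte_tensor_size, nvte_tensor_numel and nvte_tensor_element_size now sit alongside nvte_tensor_data and nvte_tensor_type, and TensorWrapper mirrors them (with numel() and bytes() returning 0 when the data pointer is null). A minimal sketch of how a caller might combine the new accessors, assuming only a valid NVTETensor obtained elsewhere; the helper name below is illustrative and not part of the API:

// Illustrative helper: recompute a tensor's allocation size from the new accessors.
// Equivalent to nvte_tensor_numel(t) * nvte_tensor_element_size(t).
#include <cstddef>
#include <transformer_engine/transformer_engine.h>

size_t tensor_alloc_bytes(const NVTETensor t) {
  size_t numel = 1;
  for (size_t d = 0; d < nvte_tensor_ndims(t); ++d) {
    numel *= nvte_tensor_size(t, d);  // size of dimension d
  }
  return numel * nvte_tensor_element_size(t);
}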
@@ -317,6 +400,6 @@ class TensorWrapper { } // namespace transformer_engine -#endif +#endif // __cplusplus #endif // TRANSFORMER_ENGINE_TRANSFORMER_ENGINE_H_ diff --git a/transformer_engine/common/pycudnn.cpp b/transformer_engine/common/pycudnn.cpp deleted file mode 100644 index 7d06f332cb..0000000000 --- a/transformer_engine/common/pycudnn.cpp +++ /dev/null @@ -1,14 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -namespace cudnn_frontend { - -// This is needed to define the symbol `cudnn_dlhandle` -// When using the flag NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING -// to enable dynamic loading. -void *cudnn_dlhandle = nullptr; - -} // namespace cudnn_frontend diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 5cfab2f8cf..1a3b49f9fa 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -93,6 +93,31 @@ NVTEShape nvte_tensor_shape(const NVTETensor tensor) { return ret; } +size_t nvte_tensor_ndims(const NVTETensor tensor) { + const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor); + return t.data.shape.size(); +} + +size_t nvte_tensor_size(const NVTETensor tensor, const size_t dim) { + const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor); + NVTE_CHECK(dim >= 0 && dim < t.data.shape.size(), "Invalid dimension index: ", dim); + return t.data.shape[dim]; +} + +size_t nvte_tensor_numel(const NVTETensor tensor) { + const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor); + size_t numel = 1; + for (auto size : t.data.shape) { + numel *= size; + } + return numel; +} + +size_t nvte_tensor_element_size(const NVTETensor tensor) { + const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor); + return transformer_engine::typeToSize(t.data.dtype); +} + void *nvte_tensor_data(const NVTETensor tensor) { const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor); return t.data.dptr; diff --git a/transformer_engine/common/util/cuda_runtime.cpp b/transformer_engine/common/util/cuda_runtime.cpp index 5728ef557a..8d2e852988 100644 --- a/transformer_engine/common/util/cuda_runtime.cpp +++ b/transformer_engine/common/util/cuda_runtime.cpp @@ -12,6 +12,7 @@ #include "../common.h" #include "../util/cuda_driver.h" #include "../util/system.h" +#include "common/util/cuda_runtime.h" namespace transformer_engine { @@ -80,6 +81,31 @@ int sm_count(int device_id) { return cache[device_id]; } +bool supports_multicast(int device_id) { +#if CUDART_VERSION >= 12010 + // NOTE: This needs to be guarded at compile time because the + // CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED enum is not defined in earlier CUDA versions.
+ static std::vector<bool> cache(num_devices(), false); + static std::vector<std::once_flag> flags(num_devices()); + if (device_id < 0) { + device_id = current_device(); + } + NVTE_CHECK(0 <= device_id && device_id < num_devices(), "invalid CUDA device ID"); + auto init = [&]() { + CUdevice cudev; + NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGet, &cudev, device_id); + int result; + NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGetAttribute, &result, + CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, cudev); + cache[device_id] = static_cast<bool>(result); + }; + std::call_once(flags[device_id], init); + return cache[device_id]; +#else + return false; +#endif +} + const std::string &include_directory(bool required) { static std::string path; diff --git a/transformer_engine/common/util/cuda_runtime.h b/transformer_engine/common/util/cuda_runtime.h index b6b4c41610..ea1ba84772 100644 --- a/transformer_engine/common/util/cuda_runtime.h +++ b/transformer_engine/common/util/cuda_runtime.h @@ -38,6 +38,14 @@ int sm_arch(int device_id = -1); */ int sm_count(int device_id = -1); +/* \brief CUDA Multicast support status for device + * + * \param[in] device_id CUDA device (default is current device) + * + * \return CUDA multicast support flag + */ +bool supports_multicast(int device_id = -1); + /* \brief Path to CUDA Toolkit headers * * The path can be configured by setting NVTE_CUDA_INCLUDE_DIR in the diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h new file mode 100644 index 0000000000..432ac815ec --- /dev/null +++ b/transformer_engine/common/util/pybind_helper.h @@ -0,0 +1,79 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information.
+ ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_COMMON_UTIL_PYBIND_HELPER_H_ +#define TRANSFORMER_ENGINE_COMMON_UTIL_PYBIND_HELPER_H_ + +#include +#include +#include +#include + +#include "cuda_runtime.h" + +#define NVTE_DECLARE_COMMON_PYBIND11_HANDLES(m) \ + pybind11::enum_(m, "DType") \ + .value("kByte", transformer_engine::DType::kByte) \ + .value("kInt32", transformer_engine::DType::kInt32) \ + .value("kFloat32", transformer_engine::DType::kFloat32) \ + .value("kFloat16", transformer_engine::DType::kFloat16) \ + .value("kBFloat16", transformer_engine::DType::kBFloat16) \ + .value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3) \ + .value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2); \ + pybind11::enum_(m, "NVTE_Bias_Type") \ + .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) \ + .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS) \ + .value("NVTE_POST_SCALE_BIAS", NVTE_Bias_Type::NVTE_POST_SCALE_BIAS) \ + .value("NVTE_ALIBI", NVTE_Bias_Type::NVTE_ALIBI); \ + pybind11::enum_(m, "NVTE_Mask_Type") \ + .value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK) \ + .value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK) \ + .value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK) \ + .value("NVTE_PADDING_CAUSAL_MASK", NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) \ + .value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) \ + .value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK", \ + NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); \ + pybind11::enum_(m, "NVTE_QKV_Layout") \ + .value("NVTE_SB3HD", NVTE_QKV_Layout::NVTE_SB3HD) \ + .value("NVTE_SBH3D", NVTE_QKV_Layout::NVTE_SBH3D) \ + .value("NVTE_SBHD_SB2HD", NVTE_QKV_Layout::NVTE_SBHD_SB2HD) \ + .value("NVTE_SBHD_SBH2D", NVTE_QKV_Layout::NVTE_SBHD_SBH2D) \ + .value("NVTE_SBHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD) \ + .value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD) \ + .value("NVTE_BSH3D", NVTE_QKV_Layout::NVTE_BSH3D) \ + .value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD) \ + .value("NVTE_BSHD_BSH2D", NVTE_QKV_Layout::NVTE_BSHD_BSH2D) \ + .value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) \ + .value("NVTE_T3HD", NVTE_QKV_Layout::NVTE_T3HD) \ + .value("NVTE_TH3D", NVTE_QKV_Layout::NVTE_TH3D) \ + .value("NVTE_THD_T2HD", NVTE_QKV_Layout::NVTE_THD_T2HD) \ + .value("NVTE_THD_TH2D", NVTE_QKV_Layout::NVTE_THD_TH2D) \ + .value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD); \ + pybind11::enum_(m, "NVTE_Fused_Attn_Backend") \ + .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) \ + .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) \ + .value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8) \ + .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend); \ + pybind11::enum_(m, "CommOverlapType") \ + .value("RS", transformer_engine::CommOverlapType::RS) \ + .value("AG", transformer_engine::CommOverlapType::AG); \ + pybind11::enum_(m, "CommOverlapAlgo") \ + .value("BULK_OVERLAP_AG", transformer_engine::CommOverlapAlgo::BULK_OVERLAP_AG) \ + .value("BULK_OVERLAP_RS", transformer_engine::CommOverlapAlgo::BULK_OVERLAP_RS) \ + .value("SPLIT_PIPELINED_AG_P2P", \ + transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_AG_P2P) \ + .value("SPLIT_PIPELINED_RS", transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_RS) \ + .value("SPLIT_PIPELINED_RS_P2P", \ + transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_RS_P2P) \ 
+ .value("ATOMIC_GEMM_RS", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS) \ + .value("ATOMIC_GEMM_AG_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_AG_P2P) \ + .value("ATOMIC_GEMM_RS_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS_P2P); \ + m.def("device_supports_multicast", &transformer_engine::cuda::supports_multicast, \ + py::call_guard(), py::arg("device_id") = -1); \ + m.def("ubuf_built_with_mpi", &transformer_engine::ubuf_built_with_mpi, \ + py::call_guard()); + +#endif diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index b3b11bb9dd..3ecc9bcd75 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -79,6 +79,19 @@ class QKVFormat(Enum): THD = NVTE_QKV_Format.NVTE_THD +class CPStrategy(Enum): + """Defines the context parallel strategies of Jax fused attention. + + DEFAULT: Default strategy will choose automatically if context parallel axis is sharded. + ALL_GATHER: All-gather/reduce scatter implementation. + RING: Ring attention implementation (https://arxiv.org/abs/2310.01889). + """ + + DEFAULT = 0 + ALL_GATHER = 1 + RING = 2 + + def get_qkv_format(qkv_layout): """ Get qkv_format from qkv_layout @@ -190,7 +203,6 @@ def is_fused_attn_kernel_available( kv_max_seqlen, head_dim, window_size: Optional[Tuple[int, int]] = None, - is_context_parallel: bool = False, ): """ To check whether the fused attention kernel is supported @@ -215,11 +227,6 @@ def make_helper(attn_mask_type): if not make_helper(attn_mask_type).is_fused_attn_kernel_available(): return False - # For context parallel need to check additional masking types - if is_context_parallel and attn_mask_type == AttnMaskType.CAUSAL_MASK: - if not make_helper(AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK).is_fused_attn_kernel_available(): - return False - return True @@ -266,6 +273,7 @@ def fused_attn( dropout_probability: float, is_training: bool, window_size: Optional[Tuple[int, int]] = None, + context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT, context_parallel_causal_load_balanced: bool = False, context_parallel_axis: str = "", ): @@ -353,6 +361,7 @@ def fused_attn( is_training=is_training, max_segments_per_seq=1, window_size=window_size, + context_parallel_strategy=context_parallel_strategy, context_parallel_causal_load_balanced=context_parallel_causal_load_balanced, context_parallel_axis=context_parallel_axis, ) @@ -376,6 +385,7 @@ def fused_attn_thd( is_training: bool, max_segments_per_seq: int = 1, window_size: Optional[Tuple[int, int]] = None, + context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT, context_parallel_causal_load_balanced: bool = False, context_parallel_axis: str = "", ): @@ -476,6 +486,7 @@ def fused_attn_thd( is_training=is_training, max_segments_per_seq=max_segments_per_seq, window_size=window_size, + context_parallel_strategy=context_parallel_strategy, context_parallel_causal_load_balanced=context_parallel_causal_load_balanced, context_parallel_axis=context_parallel_axis, ) @@ -483,7 +494,7 @@ def fused_attn_thd( return output -@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16)) +@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17)) def _fused_attn( qkv: Tuple[jnp.ndarray, ...], bias: Optional[jnp.ndarray], @@ -500,6 +511,7 @@ def _fused_attn( is_training: bool, max_segments_per_seq: int, window_size: Optional[Tuple[int, int]], + context_parallel_strategy: CPStrategy, context_parallel_causal_load_balanced: bool, context_parallel_axis: str, ): 
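With CPStrategy threaded through fused_attn and the custom VJP above, callers can opt into the ring implementation by keyword. A rough usage sketch follows; the tensor shapes, mesh setup, and any argument names not visible in these hunks (bias, mask, seed, attn_bias_type, attn_mask_type, qkv_layout, scaling_factor) are assumptions based on the surrounding API, not guaranteed signatures:

# Sketch: requesting the ring (P2P) context-parallel strategy for fused attention.
# Assumes a mesh with a "cp" axis is active and that the keyword names below match
# the public fused_attn signature; only the context_parallel_* keywords appear in this diff.
import jax.numpy as jnp
from transformer_engine.jax.attention import (
    AttnBiasType, AttnMaskType, CPStrategy, QKVLayout, fused_attn,
)

b, s, h, d = 2, 1024, 16, 64
q = jnp.zeros((b, s, h, d), jnp.bfloat16)
k = jnp.zeros((b, s, h, d), jnp.bfloat16)
v = jnp.zeros((b, s, h, d), jnp.bfloat16)

out = fused_attn(
    (q, k, v),
    bias=None,
    mask=None,
    seed=None,
    attn_bias_type=AttnBiasType.NO_BIAS,
    attn_mask_type=AttnMaskType.CAUSAL_MASK,  # ring path supports NO_MASK/CAUSAL_MASK
    qkv_layout=QKVLayout.BSHD_BSHD_BSHD,
    scaling_factor=1.0 / d**0.5,
    dropout_probability=0.0,                  # dropout is unsupported with the ring strategy
    is_training=True,
    context_parallel_strategy=CPStrategy.RING,
    context_parallel_causal_load_balanced=True,
    context_parallel_axis="cp",               # name of the context-parallel mesh axis
)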
@@ -519,6 +531,7 @@ def _fused_attn( is_training, max_segments_per_seq, window_size, + context_parallel_strategy, context_parallel_causal_load_balanced, context_parallel_axis, ) @@ -541,6 +554,7 @@ def _fused_attn_fwd_rule( is_training, max_segments_per_seq, window_size, + context_parallel_strategy, context_parallel_causal_load_balanced, context_parallel_axis, ): @@ -560,6 +574,7 @@ def _fused_attn_fwd_rule( is_training=is_training, max_segments_per_seq=max_segments_per_seq, window_size=window_size, + context_parallel_strategy=context_parallel_strategy, context_parallel_causal_load_balanced=context_parallel_causal_load_balanced, context_parallel_axis=context_parallel_axis, ) @@ -588,6 +603,7 @@ def _fused_attn_bwd_rule( is_training, max_segments_per_seq, window_size, + context_parallel_strategy, context_parallel_causal_load_balanced, context_parallel_axis, ctx, @@ -623,6 +639,7 @@ def _fused_attn_bwd_rule( is_training=is_training, max_segments_per_seq=max_segments_per_seq, window_size=window_size, + context_parallel_strategy=context_parallel_strategy, context_parallel_causal_load_balanced=context_parallel_causal_load_balanced, context_parallel_axis=context_parallel_axis, ) diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index 47483c67ea..44b396ad55 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -383,37 +383,43 @@ def lowering(ctx, x, amax, scale, scale_inv, *, out_dtype, act_enum): assert amax_aval.dtype == jnp.float32 assert scale_aval.dtype == jnp.float32 assert scale_inv_aval.dtype == jnp.float32 - ir_x_type = ir.RankedTensorType(x.type) - ir_x_shape = ir_x_type.shape - ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype) - ir_amax_type = ir.RankedTensorType(amax.type) - ir_amax_dtype = ir_amax_type.element_type - ir_amax_shape = ir_amax_type.shape - ir_scale_shape = ir_amax_shape - ir_scale_inv_shape = ir_amax_shape - - hidden_size = ir_x_shape[-1] - batch_shape = ir_x_shape[:-2] - batch_size = reduce(operator.mul, batch_shape) - out_shape = batch_shape + [hidden_size] - out_types = [ - ir.RankedTensorType.get(out_shape, ir_out_dtype), - ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype), - ] - operands = [x, amax, scale, scale_inv] - operand_shapes = [ir_x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape] - args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - - opaque = transformer_engine_jax.pack_common_descriptor( - (batch_size, hidden_size), - jax_dtype_to_te_dtype(x_aval.dtype), - jax_dtype_to_te_dtype(out_dtype), - act_enum, - ) + if is_ffi_enabled(): + name = "te_act_lu_fp8_ffi" + out = ffi.ffi_lowering(name, operand_output_aliases={1: 1})( + ctx, x, amax, scale, scale_inv, act_enum=act_enum + ) + else: + ir_x_type = ir.RankedTensorType(x.type) + ir_x_shape = ir_x_type.shape + ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype) + ir_amax_type = ir.RankedTensorType(amax.type) + ir_amax_dtype = ir_amax_type.element_type + ir_amax_shape = ir_amax_type.shape + ir_scale_shape = ir_amax_shape + ir_scale_inv_shape = ir_amax_shape - out = custom_caller( - ActLuFp8Primitive.name, args, opaque, False, operand_output_aliases={1: 1} - ) + hidden_size = ir_x_shape[-1] + batch_shape = ir_x_shape[:-2] + batch_size = reduce(operator.mul, batch_shape) + out_shape = batch_shape + [hidden_size] + out_types = [ + ir.RankedTensorType.get(out_shape, ir_out_dtype), + ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype), + ] + operands = 
[x, amax, scale, scale_inv] + operand_shapes = [ir_x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape] + args = CustomCallArgsWrapper(out_types, operands, operand_shapes) + + opaque = transformer_engine_jax.pack_common_descriptor( + (batch_size, hidden_size), + jax_dtype_to_te_dtype(x_aval.dtype), + jax_dtype_to_te_dtype(out_dtype), + act_enum, + ) + + out = custom_caller( + ActLuFp8Primitive.name, args, opaque, False, operand_output_aliases={1: 1} + ) return out diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 7246e961bd..6591861057 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -15,6 +15,9 @@ from jax.interpreters import mlir from jax.interpreters.mlir import ir from jax.sharding import PartitionSpec, NamedSharding +from jax.extend import ffi + +from transformer_engine.jax.attention import CPStrategy from transformer_engine import transformer_engine_jax from transformer_engine.transformer_engine_jax import ( @@ -33,6 +36,8 @@ te_dtype_to_jax_dtype, get_padded_spec, get_cudnn_version, + is_ffi_enabled, + get_xla_flag, ) from ..sharding import ( global_mesh_resource, @@ -275,7 +280,16 @@ def abstract( softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, kv_max_seqlen) softmax_dtype = q_dtype elif backend == NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen: - softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, config.max_segments_per_seq) + # cuDNN 9.6 reduces the required softmax shape + if get_cudnn_version() >= (9, 6, 0): + softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, 1) + else: + softmax_shape = ( + *batch_shape, + attn_heads, + q_max_seqlen, + config.max_segments_per_seq, + ) softmax_dtype = dtypes.canonicalize_dtype(jnp.float32) else: raise ValueError(f"Unsupported {backend=}") @@ -352,14 +366,6 @@ def lowering( """ Fused attention fwd lowering rules """ - operands = [q, k, v, bias, q_cu_seqlen, kv_cu_seqlen, q_seq_offsets, k_seq_offsets, seed] - operand_shapes = map(lambda x: x.type.shape, operands) - out_types = [ - ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype)) - for output in ctx.avals_out - ] - args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = ( @@ -374,33 +380,84 @@ def lowering( *bias_batch_shape, bias_heads, _, _ = bias_aval.shape bias_batch = reduce(operator.mul, bias_batch_shape) - wkspace_aval = ctx.avals_out[-1] - - opaque = transformer_engine_jax.pack_fused_attn_descriptor( - input_batch, - bias_batch, - q_max_seqlen, - kv_max_seqlen, - attn_heads, - num_gqa_groups, - bias_heads, - head_dim, - config.max_segments_per_seq, - wkspace_aval.size, - config.scaling_factor, - config.dropout_probability, - config.attn_bias_type, - config.attn_mask_type, - config.qkv_layout, - jax_dtype_to_te_dtype(q_aval.dtype), - jax_dtype_to_te_dtype(wkspace_aval.dtype), - config.is_training, - not FusedAttnHelper.is_non_deterministic_allowed(), - config.window_size[0], - config.window_size[1], - ) + if is_ffi_enabled(): + name = "te_fused_attn_forward_ffi" + out = ffi.ffi_lowering(name)( + ctx, + q, + k, + v, + bias, + q_cu_seqlen, + kv_cu_seqlen, + q_seq_offsets, + k_seq_offsets, + seed, + input_batch=input_batch, + bias_batch=bias_batch, + q_max_seqlen=q_max_seqlen, + kv_max_seqlen=kv_max_seqlen, + attn_heads=attn_heads, + 
num_gqa_groups=num_gqa_groups, + bias_heads=bias_heads, + head_dim=head_dim, + max_segments_per_seq=config.max_segments_per_seq, + scaling_factor=float(config.scaling_factor), + dropout_probability=float(config.dropout_probability), + bias_type=int(config.attn_bias_type), + mask_type=int(config.attn_mask_type), + qkv_layout=int(config.qkv_layout), + is_training=config.is_training, + deterministic=not FusedAttnHelper.is_non_deterministic_allowed(), + window_size_left=config.window_size[0], + window_size_right=config.window_size[1], + ) + else: + operands = [ + q, + k, + v, + bias, + q_cu_seqlen, + kv_cu_seqlen, + q_seq_offsets, + k_seq_offsets, + seed, + ] + operand_shapes = map(lambda x: x.type.shape, operands) + out_types = [ + ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype)) + for output in ctx.avals_out + ] + args = CustomCallArgsWrapper(out_types, operands, operand_shapes) + + wkspace_aval = ctx.avals_out[-1] + + opaque = transformer_engine_jax.pack_fused_attn_descriptor( + input_batch, + bias_batch, + q_max_seqlen, + kv_max_seqlen, + attn_heads, + num_gqa_groups, + bias_heads, + head_dim, + config.max_segments_per_seq, + wkspace_aval.size, + config.scaling_factor, + config.dropout_probability, + config.attn_bias_type, + config.attn_mask_type, + config.qkv_layout, + jax_dtype_to_te_dtype(q_aval.dtype), + jax_dtype_to_te_dtype(wkspace_aval.dtype), + config.is_training, + not FusedAttnHelper.is_non_deterministic_allowed(), + config.window_size[0], + config.window_size[1], + ) - out = custom_caller(FusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False) + out = custom_caller(FusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False) return out @@ -668,28 +725,6 @@ def lowering( """ Fused attention bwd lowering rules """ - operands = [ - q, - k, - v, - bias, - softmax_aux, - rng_state, - output, - doutput, - q_cu_seqlen, - kv_cu_seqlen, - q_seq_offsets, - k_seq_offsets, - ] - operand_shapes = map(lambda x: x.type.shape, operands) - out_types = [ - ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype)) - for output in ctx.avals_out - ] - - args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = ( @@ -704,33 +739,90 @@ def lowering( *bias_batch_shape, bias_heads, _, _ = bias_aval.shape bias_batch = reduce(operator.mul, bias_batch_shape) - wkspace_aval = ctx.avals_out[-1] - - opaque = transformer_engine_jax.pack_fused_attn_descriptor( - input_batch, - bias_batch, - q_max_seqlen, - kv_max_seqlen, - attn_heads, - num_gqa_groups, - bias_heads, - head_dim, - config.max_segments_per_seq, - wkspace_aval.size, - config.scaling_factor, - config.dropout_probability, - config.attn_bias_type, - config.attn_mask_type, - config.qkv_layout, - jax_dtype_to_te_dtype(q_aval.dtype), - jax_dtype_to_te_dtype(wkspace_aval.dtype), - config.is_training, - not FusedAttnHelper.is_non_deterministic_allowed(), - config.window_size[0], - config.window_size[1], - ) + if is_ffi_enabled(): + name = "te_fused_attn_backward_ffi" + out = ffi.ffi_lowering(name)( + ctx, + q, + k, + v, + bias, + softmax_aux, + rng_state, + output, + doutput, + q_cu_seqlen, + kv_cu_seqlen, + q_seq_offsets, + k_seq_offsets, + input_batch=input_batch, + bias_batch=bias_batch, + q_max_seqlen=q_max_seqlen, + kv_max_seqlen=kv_max_seqlen, + attn_heads=attn_heads, + num_gqa_groups=num_gqa_groups, + bias_heads=bias_heads, + head_dim=head_dim, + 
max_segments_per_seq=config.max_segments_per_seq, + scaling_factor=float(config.scaling_factor), + dropout_probability=float(config.dropout_probability), + bias_type=int(config.attn_bias_type), + mask_type=int(config.attn_mask_type), + qkv_layout=int(config.qkv_layout), + is_training=config.is_training, + deterministic=not FusedAttnHelper.is_non_deterministic_allowed(), + window_size_left=config.window_size[0], + window_size_right=config.window_size[1], + ) + else: + operands = [ + q, + k, + v, + bias, + softmax_aux, + rng_state, + output, + doutput, + q_cu_seqlen, + kv_cu_seqlen, + q_seq_offsets, + k_seq_offsets, + ] + operand_shapes = map(lambda x: x.type.shape, operands) + out_types = [ + ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype)) + for output in ctx.avals_out + ] + args = CustomCallArgsWrapper(out_types, operands, operand_shapes) + + wkspace_aval = ctx.avals_out[-1] + + opaque = transformer_engine_jax.pack_fused_attn_descriptor( + input_batch, + bias_batch, + q_max_seqlen, + kv_max_seqlen, + attn_heads, + num_gqa_groups, + bias_heads, + head_dim, + config.max_segments_per_seq, + wkspace_aval.size, + config.scaling_factor, + config.dropout_probability, + config.attn_bias_type, + config.attn_mask_type, + config.qkv_layout, + jax_dtype_to_te_dtype(q_aval.dtype), + jax_dtype_to_te_dtype(wkspace_aval.dtype), + config.is_training, + not FusedAttnHelper.is_non_deterministic_allowed(), + config.window_size[0], + config.window_size[1], + ) - out = custom_caller(FusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False) + out = custom_caller(FusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False) return out @@ -978,7 +1070,7 @@ def check_supported(self): if self.config.qkv_layout not in allowed_layouts: raise ValueError( f"{header} only supports layouts:" - f" {','.join([str(x) for x in allowed_layouts])} got: {self.config.qkv_layout}" + f" {','.join(map(str, allowed_layouts))} got: {self.config.qkv_layout}" ) if self.config.attn_bias_type != NVTE_Bias_Type.NVTE_NO_BIAS: @@ -988,7 +1080,7 @@ def check_supported(self): if self.config.attn_mask_type not in allowed_masks: raise ValueError( f"{header} only supports masking types: " - f" {','.join([str(x) for x in allowed_masks])} got: {self.config.attn_mask_type}" + f" {','.join(map(str, allowed_masks))} got: {self.config.attn_mask_type}" ) if self.config.max_segments_per_seq != 1: @@ -1357,6 +1449,503 @@ def _cross_attn_bwd( register_primitive(FusedAttnCPWithAllGatherBwdPrimitive) +@dataclass(frozen=True) +class _FusedAttnCPWithP2PHelper: + """Helper class to assist with running the P2P ring strategy for CP attention.""" + + mesh: jax.sharding.Mesh + config: _FusedAttnConfig + + @staticmethod + def use_scanloop(): + """Returns true if the implementation will use a scan loop for iteration.""" + use_scan = bool(int(os.getenv("NVTE_FUSED_RING_ATTENTION_USE_SCAN", "1"))) + + # nvbug(4675071): Disable the HLO verifier for channel ID checks. 
+ # A WAR was added to XLA: https://github.com/openxla/xla/pull/16779 + def truthy(val): + return val.lower() in ["1", "true"] + + x = use_scan and get_xla_flag( + "--xla_experimental_ignore_channel_id", default=False, cast=truthy + ) + return x + + def check_supported(self): + """Checks if the context parallel implementation is supported by the given arguments.""" + header = "Context parallel fused ring attention" + + allowed_layouts = [NVTE_QKV_Layout.NVTE_BSHD_BS2HD, NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD] + if self.config.qkv_layout not in allowed_layouts: + raise ValueError( + f"{header} only supports layouts:" + f" {','.join(map(str, allowed_layouts))} got: {self.config.qkv_layout}" + ) + + if self.config.attn_bias_type != NVTE_Bias_Type.NVTE_NO_BIAS: + raise ValueError(f"{header} does not support bias got: {self.config.attn_bias_type}") + + allowed_masks = [NVTE_Mask_Type.NVTE_NO_MASK, NVTE_Mask_Type.NVTE_CAUSAL_MASK] + if self.config.attn_mask_type not in allowed_masks: + raise ValueError( + f"{header} only supports masking types: " + f" {','.join(map(str, allowed_masks))} got: {self.config.attn_mask_type}" + ) + + if self.config.max_segments_per_seq != 1: + raise ValueError( + f"{header} only supports max_segments_per_seq == 1 got:" + f" {self.config.max_segments_per_seq}" + ) + + if self.config.dropout_probability != 0.0: + raise ValueError(f"{header} does not support dropout") + + # We want to encourage use of scan loop to minimize unrolling and ensure more + # predictable scheduling from XLA. The unrolled flavor will be supported but + # not the prefered implementation. + if not self.use_scanloop(): + warnings.warn( + "Scan loop is disabled for fused ring attention. To enable set" + " NVTE_FUSED_RING_ATTENTION_USE_SCAN=1 in your environment and" + " add --xla_experimental_ignore_channel_id=true to XLA_FLAGS." 
+ ) + + def get_step_config(self, attn_mask_type) -> _FusedAttnConfig: + """Returns a _FusedAttnConfig for single CP step call to fused attention.""" + return _FusedAttnConfig( + attn_bias_type=self.config.attn_bias_type, + attn_mask_type=attn_mask_type, + qkv_layout=NVTE_QKV_Layout.NVTE_BSHD_BS2HD, + scaling_factor=self.config.scaling_factor, + dropout_probability=self.config.dropout_probability, + is_training=self.config.is_training, + max_segments_per_seq=self.config.max_segments_per_seq, + window_size=self.config.window_size, + context_parallel_load_balanced=self.config.context_parallel_load_balanced, + cp_axis=self.config.cp_axis, + ) + + def stack_kv(self, k, v): + """Stacks k and v tensors if not stacked.""" + _not_used = jnp.zeros(0, dtype=k.dtype) + match self.config.qkv_layout: + case NVTE_QKV_Layout.NVTE_BSHD_BS2HD: + return k + case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD: + return jnp.stack([k, v], axis=2) + return _not_used + + def unstack_kv(self, kv): + """Un-stacks k and v tensors if not stacked.""" + _not_used = jnp.zeros(0, dtype=kv.dtype) + match self.config.qkv_layout: + case NVTE_QKV_Layout.NVTE_BSHD_BS2HD: + return kv, _not_used + case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD: + return jnp.unstack(kv, axis=2) + return _not_used, _not_used # fall through + + def permute_kv(self, kv, cp_perm): + """Permutes kv around the ring as described by cp_perm.""" + return lax_paral_op(kv, lax.ppermute, self.config.cp_axis, mesh=self.mesh, perm=cp_perm) + + def correct_softmax_aux(self, softmax_aux, softmax_aux_per_step): + """Apply soft max correction after an attention step.""" + max_scale = jnp.maximum(softmax_aux, softmax_aux_per_step) + min_scale = jnp.minimum(softmax_aux, softmax_aux_per_step) + new_softmax_aux = max_scale + jnp.log(1 + jnp.exp(min_scale - max_scale)) + return new_softmax_aux + + def adjust_seqlen(self, seqlen, max_seqlen, idx): + """Adjust the sequence length per step.""" + seqlen_of_curr_step = seqlen - max_seqlen * idx + seqlen_of_curr_step = jnp.where(seqlen_of_curr_step < 0, 0, seqlen_of_curr_step) + seqlen_per_step = jnp.where( + seqlen_of_curr_step < max_seqlen, seqlen_of_curr_step, max_seqlen + ) + return seqlen_per_step + + +class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive): + """ + Fused Ring Attention Forward Primitive + """ + + @staticmethod + def partition(config, mesh, arg_infos, result_infos): + is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1 + assert ( + not is_context_parallel or config.window_size[0] == -1 + ), "Sliding window attention is not supported when context parallelism is enabled" + if not is_context_parallel: + return FusedAttnFwdPrimitive.partition(config, mesh, arg_infos, result_infos) + + helper = _FusedAttnCPWithP2PHelper(mesh, config) + helper.check_supported() + + out_sharding = result_infos[0].sharding + softmax_aux_sharding = result_infos[1].sharding + rng_state_sharding = seed_sharding = NamedSharding( + mesh, PartitionSpec(get_all_mesh_axes(), None) + ) + arg_shardings = tuple([arg_i.sharding for arg_i in arg_infos[:-1]] + [seed_sharding]) + out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding) + + def ring_attn_fwd_impl( + q, + k, + v, + bias, + q_seqlen, + kv_seqlen, + q_seq_offsets, + k_seq_offsets, + seed, + ): + _not_used = jnp.zeros(0, dtype=v.dtype) + + # Combine KV tensors if separate for better permute scheduling and performance. + # Eventually XLA should perform this automatically. 
+ kv = helper.stack_kv(k, v) + + batch, q_max_seqlen, head, _ = q.shape + kv_max_seqlen = k.shape[1] + + cp_size = get_mesh_axis_size(config.cp_axis, mesh) + cp_rank = get_mesh_axis_rank(config.cp_axis, mesh) + cp_perm = [(i, (i + 1) % cp_size) for i in range(cp_size)] + + output_per_steps = jnp.zeros((cp_size, *q.shape), dtype=q.dtype) + softmax_aux_per_steps = jnp.zeros( + (cp_size, batch, head, q_max_seqlen, 1), dtype=jnp.float32 + ) + softmax_aux = jnp.full((batch, head, q_max_seqlen, 1), -jnp.inf, dtype=jnp.float32) + + # RNG shape should be the shared shape. This is unused for ring attention as we do not + # support dropout currently. + rng_state_shape = (result_infos[2].shape[0] // mesh.size, *result_infos[2].shape[1:]) + rng_state = jnp.zeros(rng_state_shape).astype(result_infos[2].dtype) + + def scan_kv_block(idx, carry): + kv, softmax_aux, output_per_steps, softmax_aux_per_steps = carry + + # Send KV block to next step so we can overlap compute. + kv_next = helper.permute_kv(kv, cp_perm) + + def mask_compute(attn_mask_type): + q_seqlen_per_step = helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx) + kv_seqlen_per_step = helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx) + output_per_step, softmax_aux_per_step, _ = FusedAttnFwdPrimitive.impl( + q, + kv, + _not_used, + bias, + q_seqlen_per_step, + kv_seqlen_per_step, + q_seq_offsets, + k_seq_offsets, + seed, + helper.get_step_config(attn_mask_type), + ) + return output_per_step, softmax_aux_per_step + + causal_mask_compute = partial(mask_compute, NVTE_Mask_Type.NVTE_CAUSAL_MASK) + no_mask_compute = partial(mask_compute, NVTE_Mask_Type.NVTE_NO_MASK) + + def half_kv_no_mask_compute(): + q_seqlen_per_step = helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx) + kv_seqlen_per_step = helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx) // 2 + kv_part = lax.slice_in_dim(kv, 0, kv.shape[1] // 2, axis=1) + output_per_step, softmax_aux_per_step, _ = FusedAttnFwdPrimitive.impl( + q, + kv_part, + _not_used, + bias, + q_seqlen_per_step, + kv_seqlen_per_step, + q_seq_offsets, + k_seq_offsets, + seed, + config=helper.get_step_config(NVTE_Mask_Type.NVTE_NO_MASK), + ) + return output_per_step, softmax_aux_per_step + + def half_q_no_mask_compute(): + q_seqlen_per_step = helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx) // 2 + kv_seqlen_per_step = helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx) + q_part = lax.slice_in_dim(q, q_max_seqlen // 2, q_max_seqlen, axis=1) + output_per_step, softmax_aux_per_step, _ = FusedAttnFwdPrimitive.impl( + q_part, + kv, + _not_used, + bias, + q_seqlen_per_step, + kv_seqlen_per_step, + q_seq_offsets, + k_seq_offsets, + seed, + config=helper.get_step_config(NVTE_Mask_Type.NVTE_NO_MASK), + ) + output_per_step = jnp.concat([jnp.zeros_like(q_part), output_per_step], axis=1) + softmax_aux_per_step = jnp.concat( + [ + jnp.full_like(softmax_aux_per_step, -jnp.inf), + softmax_aux_per_step, + ], + axis=2, + ) + return output_per_step, softmax_aux_per_step + + def skip_compute(): + output_per_step = jnp.zeros_like(q) + softmax_aux_per_step = jnp.full( + (batch, head, q.shape[1], 1), -jnp.inf, dtype=jnp.float32 + ) + return output_per_step, softmax_aux_per_step + + if config.attn_mask_type == NVTE_Mask_Type.NVTE_CAUSAL_MASK: + # This is for nested jax.lax.cond + def jax_cond_wrap(): + if config.context_parallel_load_balanced: + return lax.cond( + (idx <= cp_rank), half_kv_no_mask_compute, half_q_no_mask_compute + ) + return lax.cond((idx <= cp_rank), no_mask_compute, skip_compute) + + output_per_step, softmax_aux_per_step = 
lax.cond( + idx == 0, causal_mask_compute, jax_cond_wrap + ) + else: + output_per_step, softmax_aux_per_step = no_mask_compute() + + softmax_aux = helper.correct_softmax_aux(softmax_aux, softmax_aux_per_step) + output_per_steps = output_per_steps.at[idx].set(output_per_step) + softmax_aux_per_steps = softmax_aux_per_steps.at[idx].set(softmax_aux_per_step) + + return (kv_next, softmax_aux, output_per_steps, softmax_aux_per_steps) + + carry = (kv, softmax_aux, output_per_steps, softmax_aux_per_steps) + if helper.use_scanloop(): + carry = lax.fori_loop(0, cp_size, scan_kv_block, carry) + else: + for i in range(0, cp_size): + carry = scan_kv_block(i, carry) + (kv, softmax_aux, output_per_steps, softmax_aux_per_steps) = carry + + output = jnp.zeros(q.shape).astype(jnp.float32) + for idx in range(cp_size): + output = output + output_per_steps[idx].astype(jnp.float32) * jnp.exp( + softmax_aux_per_steps[idx] - softmax_aux + ).transpose(0, 2, 1, 3) + output = output.astype(q.dtype) + return output, softmax_aux, rng_state + + return mesh, ring_attn_fwd_impl, out_shardings, arg_shardings + + +register_primitive(FusedRingAttnFwdPrimitive) + + +class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive): + """ + Fused Ring Attention Backward Primitive + """ + + @staticmethod + def partition(config, mesh, arg_infos, result_infos): + is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1 + assert ( + not is_context_parallel or config.window_size[0] == -1 + ), "Sliding window attention is not supported when context parallelism is enabled" + if not is_context_parallel: + return FusedAttnBwdPrimitive.partition(config, mesh, arg_infos, result_infos) + + del result_infos + q_spec = get_padded_spec(arg_infos[0]) + k_spec = get_padded_spec(arg_infos[1]) + v_spec = get_padded_spec(arg_infos[2]) + bias_spec = get_padded_spec(arg_infos[3]) + dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) + dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec)) + dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec)) + dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec)) + arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) + out_shardings = (dq_sharding, dk_sharding, dv_sharding, dbias_sharding) + + helper = _FusedAttnCPWithP2PHelper(mesh, config) + helper.check_supported() + + def ring_attn_bwd_impl( + q, + k, + v, + bias, + softmax_aux, + rng_state, + output, + doutput, + q_seqlen, + kv_seqlen, + q_seq_offsets, + k_seq_offsets, + ): + _not_used = jnp.zeros(0, dtype=output.dtype) + + # Combine KV tensors if separate for better permute scheduling and performance. + # Eventually XLA should perform this automatically. + kv = helper.stack_kv(k, v) + + q_max_seqlen = q.shape[1] + kv_max_seqlen = k.shape[1] + + cp_size = get_mesh_axis_size(config.cp_axis, mesh) + cp_rank = get_mesh_axis_rank(config.cp_axis, mesh) + cp_perm = [(i, (i + 1) % cp_size) for i in range(cp_size)] + + dq = jnp.zeros_like(q) + dk_dv = helper.stack_kv(jnp.zeros_like(k), jnp.zeros_like(v)) + dbias = jnp.zeros_like(bias) + + def scan_kv_block(idx, carry): + + kv, dq, dk_dv, dbias = carry + + # Start communication that feeds the next iteraton. + # We further combine the tensors to improve overlap. 
+ + kv_dk_dv = jnp.stack([kv, dk_dv]) + kv_dk_dv = helper.permute_kv(kv_dk_dv, cp_perm) + + def mask_compute(attn_mask_type): + q_seqlen_per_step = helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx) + kv_seqlen_per_step = helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx) + dq_per_step, dk_dv_per_step, _, dbias_per_step = FusedAttnBwdPrimitive.impl( + q, + kv, + _not_used, + bias, + softmax_aux, + rng_state, + output, + doutput, + q_seqlen_per_step, + kv_seqlen_per_step, + q_seq_offsets, + k_seq_offsets, + config=helper.get_step_config(attn_mask_type), + ) + return dq_per_step, dk_dv_per_step, dbias_per_step + + causal_mask_compute = partial(mask_compute, NVTE_Mask_Type.NVTE_CAUSAL_MASK) + no_mask_compute = partial(mask_compute, NVTE_Mask_Type.NVTE_NO_MASK) + + def half_kv_no_mask_compute(): + q_seqlen_per_step = helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx) + kv_seqlen_per_step = helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx) // 2 + kv_part = lax.slice_in_dim(kv, 0, kv_max_seqlen // 2, axis=1) + dq_per_step, dk_dv_per_step, _, dbias_per_step = FusedAttnBwdPrimitive.impl( + q, + kv_part, + _not_used, + bias, + softmax_aux, + rng_state, + output, + doutput, + q_seqlen_per_step, + kv_seqlen_per_step, + q_seq_offsets, + k_seq_offsets, + config=helper.get_step_config(NVTE_Mask_Type.NVTE_NO_MASK), + ) + dk_dv_per_step = jnp.concat( + [dk_dv_per_step, jnp.zeros_like(dk_dv_per_step)], axis=1 + ) + return dq_per_step, dk_dv_per_step, dbias_per_step + + def half_q_no_mask_compute(): + q_seqlen_per_step = helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx) // 2 + kv_seqlen_per_step = helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx) + + q_part = lax.slice_in_dim(q, q_max_seqlen // 2, q_max_seqlen, axis=1) + doutput_part = lax.slice_in_dim( + doutput, q_max_seqlen // 2, q_max_seqlen, axis=1 + ) + output_part = lax.slice_in_dim(output, q_max_seqlen // 2, q_max_seqlen, axis=1) + + softmax_aux_part = lax.slice_in_dim( + softmax_aux, q_max_seqlen // 2, q_max_seqlen, axis=2 + ) + + dq_per_step, dk_dv_per_step, _, dbias_per_step = FusedAttnBwdPrimitive.impl( + q_part, + kv, + _not_used, + bias, + softmax_aux_part, + rng_state, + output_part, + doutput_part, + q_seqlen_per_step, + kv_seqlen_per_step, + q_seq_offsets, + k_seq_offsets, + config=helper.get_step_config(NVTE_Mask_Type.NVTE_NO_MASK), + ) + dq_per_step = jnp.concat([jnp.zeros_like(dq_per_step), dq_per_step], axis=1) + return dq_per_step, dk_dv_per_step, dbias_per_step + + def skip_compute(): + return jnp.zeros_like(q), jnp.zeros_like(kv), jnp.zeros_like(bias) + + if config.attn_mask_type == NVTE_Mask_Type.NVTE_CAUSAL_MASK: + # This is for nested jax.lax.cond + def jax_cond_wrap(): + if config.context_parallel_load_balanced: + return lax.cond( + (idx <= cp_rank), half_kv_no_mask_compute, half_q_no_mask_compute + ) + return lax.cond((idx <= cp_rank), no_mask_compute, skip_compute) + + dq_per_step, dk_dv_per_step, dbias_per_step = lax.cond( + idx == 0, causal_mask_compute, jax_cond_wrap + ) + else: + dq_per_step, dk_dv_per_step, dbias_per_step = no_mask_compute() + + kv_next, dk_dv = jnp.unstack(kv_dk_dv) + dq = dq + dq_per_step + dk_dv = dk_dv + dk_dv_per_step + if config.attn_bias_type is not NVTE_Bias_Type.NVTE_NO_BIAS: + dbias = dbias + dbias_per_step + + return (kv_next, dq, dk_dv, dbias) + + carry = (kv, dq, dk_dv, dbias) + if helper.use_scanloop(): + carry = lax.fori_loop(0, cp_size, scan_kv_block, carry) + else: + for i in range(0, cp_size): + carry = scan_kv_block(i, carry) + (kv, dq, dk_dv, dbias) = carry + + # Final permute to put 
gradients back to their final resting place. + dk_dv = helper.permute_kv(dk_dv, cp_perm) + + global_dbias = dbias + if config.attn_bias_type is not NVTE_Bias_Type.NVTE_NO_BIAS: + global_dbias = all_reduce_sum_along_dp_fsdp(dbias, mesh) + + dk, dv = helper.unstack_kv(dk_dv) + return dq, dk, dv, global_dbias + + return mesh, ring_attn_bwd_impl, out_shardings, arg_shardings + + +register_primitive(FusedRingAttnBwdPrimitive) + + def _maybe_context_parallel_axis(cp_axis: str): if not cp_axis: gmr = global_mesh_resource() @@ -1383,6 +1972,7 @@ def fused_attn_fwd( is_training: bool, max_segments_per_seq: int, window_size: Optional[Tuple[int, int]] = None, + context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT, context_parallel_causal_load_balanced: bool = False, context_parallel_axis: str = "", ) -> jnp.ndarray: @@ -1465,7 +2055,14 @@ def fused_attn_fwd( cp_axis=_maybe_context_parallel_axis(context_parallel_axis), ) - return FusedAttnCPWithAllGatherFwdPrimitive.outer_primitive.bind( + primative = None + match context_parallel_strategy: + case CPStrategy.DEFAULT | CPStrategy.ALL_GATHER: + primative = FusedAttnCPWithAllGatherFwdPrimitive.outer_primitive + case CPStrategy.RING: + primative = FusedRingAttnFwdPrimitive.outer_primitive + + return primative.bind( *qkv_for_primitive, bias, q_seqlen, @@ -1496,6 +2093,7 @@ def fused_attn_bwd( is_training: bool, max_segments_per_seq: int, window_size: Optional[Tuple[int, int]] = None, + context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT, context_parallel_causal_load_balanced: bool = False, context_parallel_axis: str = "", ): @@ -1582,7 +2180,14 @@ def fused_attn_bwd( cp_axis=_maybe_context_parallel_axis(context_parallel_axis), ) - *qkv_grads, bias_grad = FusedAttnCPWithAllGatherBwdPrimitive.outer_primitive.bind( + primative = None + match context_parallel_strategy: + case CPStrategy.DEFAULT | CPStrategy.ALL_GATHER: + primative = FusedAttnCPWithAllGatherBwdPrimitive.outer_primitive + case CPStrategy.RING: + primative = FusedRingAttnBwdPrimitive.outer_primitive + + *qkv_grads, bias_grad = primative.bind( *qkv_for_primitive, bias, softmax_aux, diff --git a/transformer_engine/jax/cpp_extensions/custom_call.py b/transformer_engine/jax/cpp_extensions/custom_call.py index 8e58ed3bed..1075030a0d 100644 --- a/transformer_engine/jax/cpp_extensions/custom_call.py +++ b/transformer_engine/jax/cpp_extensions/custom_call.py @@ -5,8 +5,8 @@ from dataclasses import dataclass from enum import IntEnum -from jax.lib import xla_client from jax.interpreters import mlir +import jax.extend as jex from transformer_engine import transformer_engine_jax @@ -30,12 +30,11 @@ class CustomCallAPIVersion(IntEnum): for _name, _value in transformer_engine_jax.registrations().items(): if _name.endswith("_ffi"): if is_ffi_enabled(): - # COMMAND_BUFFER_COMPATIBLE i.e. 
cudaGraph enabled by default - xla_client.register_custom_call_target( + jex.ffi.register_ffi_target( _name, _value, platform="CUDA", api_version=CustomCallAPIVersion.FFI.value ) else: - xla_client.register_custom_call_target( + jex.ffi.register_ffi_target( _name, _value, platform="CUDA", api_version=CustomCallAPIVersion.OPAQUE.value ) diff --git a/transformer_engine/jax/cpp_extensions/misc.py b/transformer_engine/jax/cpp_extensions/misc.py index 58b8db4c88..1f13484b98 100644 --- a/transformer_engine/jax/cpp_extensions/misc.py +++ b/transformer_engine/jax/cpp_extensions/misc.py @@ -162,8 +162,30 @@ def is_ffi_enabled(): """ Helper function checking if XLA Custom Call with FFI is enabled """ - is_supported = jax_version_meet_requirement("0.4.31") + is_supported = jax_version_meet_requirement("0.4.35") # New APIs with FFI are enabled by default is_enabled = int(os.getenv("NVTE_JAX_WITH_FFI", "1")) assert is_enabled in (0, 1), "Invalid NVTE_JAX_WITH_FFI value" return is_supported and is_enabled + + +def get_xla_flag(flag: str, default=None, cast=str): + """ + Returns the value of a flag/option in XLA_FLAGS environment variable if present or returns the default value. + """ + xla_flags = [] + if xla_flags_env := os.getenv("XLA_FLAGS"): + xla_flags.extend(xla_flags_env.split()) + + for flag_i in sorted(xla_flags): + if "=" in flag_i: + # option like --xla_abc=foo + name, val = flag_i.split("=", 2) + if name == flag: + return val if cast is None else cast(val) + else: + # flag like --xla_enable_foo + name, val = flag_i, None + if name == flag: + return True + return default diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index e85f28a06a..fd6cc09de9 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -13,6 +13,7 @@ from jax.interpreters import mlir from jax.interpreters.mlir import ir from jax.sharding import PartitionSpec, NamedSharding +from jax.extend import ffi from transformer_engine import transformer_engine_jax from transformer_engine.transformer_engine_jax import DType as TEDType @@ -25,6 +26,7 @@ jax_dtype_to_te_dtype, jax_dtype_to_ir_dtype, te_dtype_to_jax_dtype, + is_ffi_enabled, ) from .quantization import _jax_cast_fp8 from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp @@ -125,51 +127,68 @@ def lowering(ctx, x, gamma, beta, *, zero_centered_gamma, epsilon): assert g_type == b_type assert g_shape == b_shape - # Output shape is same as the input shape, but the output type is same as the weight type. 
- # See ln_api.cpp - output_type = g_type.element_type - ir_mu_dtype = ir.F32Type.get() - ir_rsigma_dtype = ir.F32Type.get() - - out_shape = x_shape - hidden_size = reduce(operator.mul, g_shape) - batch_shape = out_shape[:-1] - batch_size = reduce(operator.mul, x_shape) // hidden_size - - wkspace_aval, barrier_aval = ctx.avals_out[-2:] - - out_types = [ - ir.RankedTensorType.get(out_shape, output_type), - ir.RankedTensorType.get(batch_shape, ir_mu_dtype), - ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype), - ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)), - ir.RankedTensorType.get(barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)), - ] - operands = [x, gamma, beta] - operand_shapes = [x_shape, g_shape, b_shape] - args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - - sm_margin = get_forward_sm_margin() - - opaque = transformer_engine_jax.pack_norm_descriptor( - batch_size, - hidden_size, - wkspace_aval.size, - barrier_aval.size, - (0,), # no dgamma_part in FWD pass - (0,), # no dbeta_part in BWD pass - jax_dtype_to_te_dtype(x_aval.dtype), - jax_dtype_to_te_dtype(gamma_aval.dtype), - jax_dtype_to_te_dtype(wkspace_aval.dtype), - jax_dtype_to_te_dtype(barrier_aval.dtype), - TEDType.kByte, # dummy dgamma_part te_dtype - TEDType.kByte, # dummy dbeta_part te_dtype - zero_centered_gamma, - epsilon, - sm_margin, - ) + if is_ffi_enabled(): + name = "te_layernorm_forward_ffi" + sm_margin = get_forward_sm_margin() + out = ffi.ffi_lowering(name)( + ctx, + x, + gamma, + beta, + zero_centered_gamma=zero_centered_gamma, + eps=epsilon, + sm_margin=sm_margin, + ) + else: + # Output shape is same as the input shape, but the output type is same as the weight type. + # See ln_api.cpp + output_type = g_type.element_type + ir_mu_dtype = ir.F32Type.get() + ir_rsigma_dtype = ir.F32Type.get() + + out_shape = x_shape + hidden_size = reduce(operator.mul, g_shape) + batch_shape = out_shape[:-1] + batch_size = reduce(operator.mul, x_shape) // hidden_size + + wkspace_aval, barrier_aval = ctx.avals_out[-2:] + + out_types = [ + ir.RankedTensorType.get(out_shape, output_type), + ir.RankedTensorType.get(batch_shape, ir_mu_dtype), + ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype), + ir.RankedTensorType.get( + wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype) + ), + ir.RankedTensorType.get( + barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype) + ), + ] + operands = [x, gamma, beta] + operand_shapes = [x_shape, g_shape, b_shape] + args = CustomCallArgsWrapper(out_types, operands, operand_shapes) + + sm_margin = get_forward_sm_margin() + + opaque = transformer_engine_jax.pack_norm_descriptor( + batch_size, + hidden_size, + wkspace_aval.size, + barrier_aval.size, + (0,), # no dgamma_part in FWD pass + (0,), # no dbeta_part in BWD pass + jax_dtype_to_te_dtype(x_aval.dtype), + jax_dtype_to_te_dtype(gamma_aval.dtype), + jax_dtype_to_te_dtype(wkspace_aval.dtype), + jax_dtype_to_te_dtype(barrier_aval.dtype), + TEDType.kByte, # dummy dgamma_part te_dtype + TEDType.kByte, # dummy dbeta_part te_dtype + zero_centered_gamma, + epsilon, + sm_margin, + ) - out = custom_caller(LayerNormFwdPrimitive.name, args, opaque, False) + out = custom_caller(LayerNormFwdPrimitive.name, args, opaque, False) return out @@ -418,44 +437,59 @@ def lowering(ctx, dz, x, mu, rsigma, gamma, *, zero_centered_gamma, epsilon): assert g_type == b_type assert g_shape == b_shape - dz_shape = ir.RankedTensorType(dz.type).shape - mu_shape = 
ir.RankedTensorType(mu.type).shape - rsigma_shape = ir.RankedTensorType(rsigma.type).shape - - hidden_size = reduce(operator.mul, g_shape) - batch_size = reduce(operator.mul, x_shape) // hidden_size - - out_types = [ - ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype)) - for output in ctx.avals_out - ] - - operands = [dz, mu, rsigma, x, gamma] - operand_shapes = [dz_shape, mu_shape, rsigma_shape, x_shape, g_shape] - args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - - sm_margin = get_backward_sm_margin() - - wkspace_aval, barrier_aval, dgamma_part_aval, dbeta_part_aval = ctx.avals_out[-4:] - opaque = transformer_engine_jax.pack_norm_descriptor( - batch_size, - hidden_size, - wkspace_aval.size, - barrier_aval.size, - dgamma_part_aval.shape, - dbeta_part_aval.shape, - jax_dtype_to_te_dtype(x_aval.dtype), - jax_dtype_to_te_dtype(gamma_aval.dtype), - jax_dtype_to_te_dtype(wkspace_aval.dtype), - jax_dtype_to_te_dtype(barrier_aval.dtype), - jax_dtype_to_te_dtype(dgamma_part_aval.dtype), - jax_dtype_to_te_dtype(dbeta_part_aval.dtype), - zero_centered_gamma, - epsilon, - sm_margin, - ) + if is_ffi_enabled(): + name = "te_layernorm_backward_ffi" + sm_margin = get_backward_sm_margin() + out = ffi.ffi_lowering(name)( + ctx, + dz, + x, + mu, + rsigma, + gamma, + zero_centered_gamma=zero_centered_gamma, + eps=epsilon, + sm_margin=sm_margin, + ) + else: + dz_shape = ir.RankedTensorType(dz.type).shape + mu_shape = ir.RankedTensorType(mu.type).shape + rsigma_shape = ir.RankedTensorType(rsigma.type).shape + + hidden_size = reduce(operator.mul, g_shape) + batch_size = reduce(operator.mul, x_shape) // hidden_size + + out_types = [ + ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype)) + for output in ctx.avals_out + ] + + operands = [dz, mu, rsigma, x, gamma] + operand_shapes = [dz_shape, mu_shape, rsigma_shape, x_shape, g_shape] + args = CustomCallArgsWrapper(out_types, operands, operand_shapes) + + sm_margin = get_backward_sm_margin() + + wkspace_aval, barrier_aval, dgamma_part_aval, dbeta_part_aval = ctx.avals_out[-4:] + opaque = transformer_engine_jax.pack_norm_descriptor( + batch_size, + hidden_size, + wkspace_aval.size, + barrier_aval.size, + dgamma_part_aval.shape, + dbeta_part_aval.shape, + jax_dtype_to_te_dtype(x_aval.dtype), + jax_dtype_to_te_dtype(gamma_aval.dtype), + jax_dtype_to_te_dtype(wkspace_aval.dtype), + jax_dtype_to_te_dtype(barrier_aval.dtype), + jax_dtype_to_te_dtype(dgamma_part_aval.dtype), + jax_dtype_to_te_dtype(dbeta_part_aval.dtype), + zero_centered_gamma, + epsilon, + sm_margin, + ) - out = custom_caller(LayerNormBwdPrimitive.name, args, opaque, False) + out = custom_caller(LayerNormBwdPrimitive.name, args, opaque, False) return out @@ -629,51 +663,68 @@ def lowering(ctx, x, gamma, *, epsilon): """ RMSNorm fwd lowering rules """ - x_aval, gamma_aval = ctx.avals_in - x_type = ir.RankedTensorType(x.type) - x_shape = x_type.shape - g_type = ir.RankedTensorType(gamma.type) - g_shape = g_type.shape - rsigma_element_type = ir.F32Type.get() - - out_shape = x_shape - hidden_size = reduce(operator.mul, g_shape) - batch_shape = out_shape[:-1] - batch_size = reduce(operator.mul, x_shape) // hidden_size - - wkspace_aval, barrier_aval = ctx.avals_out[-2:] - - out_types = [ - ir.RankedTensorType.get(out_shape, x_type.element_type), - ir.RankedTensorType.get(batch_shape, rsigma_element_type), - ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)), - ir.RankedTensorType.get(barrier_aval.shape, 
jax_dtype_to_ir_dtype(barrier_aval.dtype)), - ] - operands = [x, gamma] - operand_shapes = [x_shape, g_shape] - args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - - sm_margin = get_forward_sm_margin() - - opaque = transformer_engine_jax.pack_norm_descriptor( - batch_size, - hidden_size, - wkspace_aval.size, - barrier_aval.size, - (0,), # no dgamma_part in FWD pass - (0,), # no dbeta_part in BWD pass - jax_dtype_to_te_dtype(x_aval.dtype), - jax_dtype_to_te_dtype(gamma_aval.dtype), - jax_dtype_to_te_dtype(wkspace_aval.dtype), - jax_dtype_to_te_dtype(barrier_aval.dtype), - TEDType.kByte, # dummy dgamma_part te_dtype - TEDType.kByte, # dummy dbeta_part te_dtype - False, # RMSNorm doesn't support zero_centered_gamma - epsilon, - sm_margin, - ) + if is_ffi_enabled(): + name = "te_rmsnorm_forward_ffi" + sm_margin = get_forward_sm_margin() + zero_centered_gamma = False # RMSNorm doesn't support zero_centered_gamma + out = ffi.ffi_lowering(name)( + ctx, + x, + gamma, + zero_centered_gamma=zero_centered_gamma, + eps=epsilon, + sm_margin=sm_margin, + ) + else: + x_aval, gamma_aval = ctx.avals_in + x_type = ir.RankedTensorType(x.type) + x_shape = x_type.shape + g_type = ir.RankedTensorType(gamma.type) + g_shape = g_type.shape + rsigma_element_type = ir.F32Type.get() + + out_shape = x_shape + hidden_size = reduce(operator.mul, g_shape) + batch_shape = out_shape[:-1] + batch_size = reduce(operator.mul, x_shape) // hidden_size + + wkspace_aval, barrier_aval = ctx.avals_out[-2:] + + out_types = [ + ir.RankedTensorType.get(out_shape, x_type.element_type), + ir.RankedTensorType.get(batch_shape, rsigma_element_type), + ir.RankedTensorType.get( + wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype) + ), + ir.RankedTensorType.get( + barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype) + ), + ] + operands = [x, gamma] + operand_shapes = [x_shape, g_shape] + args = CustomCallArgsWrapper(out_types, operands, operand_shapes) + + sm_margin = get_forward_sm_margin() + + opaque = transformer_engine_jax.pack_norm_descriptor( + batch_size, + hidden_size, + wkspace_aval.size, + barrier_aval.size, + (0,), # no dgamma_part in FWD pass + (0,), # no dbeta_part in BWD pass + jax_dtype_to_te_dtype(x_aval.dtype), + jax_dtype_to_te_dtype(gamma_aval.dtype), + jax_dtype_to_te_dtype(wkspace_aval.dtype), + jax_dtype_to_te_dtype(barrier_aval.dtype), + TEDType.kByte, # dummy dgamma_part te_dtype + TEDType.kByte, # dummy dbeta_part te_dtype + False, # RMSNorm doesn't support zero_centered_gamma + epsilon, + sm_margin, + ) - out = custom_caller(RmsNormFwdPrimitive.name, args, opaque, False) + out = custom_caller(RmsNormFwdPrimitive.name, args, opaque, False) return out @@ -819,53 +870,72 @@ def lowering(ctx, dz, x, rsigma, gamma, *, epsilon): """ RMSNorm bwd lowering rules """ - _, x_aval, _, gamma_aval = ctx.avals_in - x_type = ir.RankedTensorType(x.type) - x_shape = x_type.shape - g_type = ir.RankedTensorType(gamma.type) - g_shape = g_type.shape - dz_shape = ir.RankedTensorType(dz.type).shape - rsigma_shape = ir.RankedTensorType(rsigma.type).shape - - hidden_size = reduce(operator.mul, g_shape) - batch_size = reduce(operator.mul, x_shape) // hidden_size - - wkspace_aval, barrier_aval, dgamma_part_aval = ctx.avals_out[-3:] - - out_types = [ - ir.RankedTensorType.get(x_shape, x_type.element_type), - ir.RankedTensorType.get(g_shape, g_type.element_type), - ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)), - ir.RankedTensorType.get(barrier_aval.shape, 
jax_dtype_to_ir_dtype(barrier_aval.dtype)), - ir.RankedTensorType.get( - dgamma_part_aval.shape, jax_dtype_to_ir_dtype(dgamma_part_aval.dtype) - ), - ] - operands = [dz, rsigma, x, gamma] - operand_shapes = [dz_shape, rsigma_shape, x_shape, g_shape] - args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - - sm_margin = get_backward_sm_margin() - - opaque = transformer_engine_jax.pack_norm_descriptor( - batch_size, - hidden_size, - wkspace_aval.size, - barrier_aval.size, - dgamma_part_aval.shape, - (0,), # no dbeta_part for RMSnorm - jax_dtype_to_te_dtype(x_aval.dtype), - jax_dtype_to_te_dtype(gamma_aval.dtype), - jax_dtype_to_te_dtype(wkspace_aval.dtype), - jax_dtype_to_te_dtype(barrier_aval.dtype), - jax_dtype_to_te_dtype(dgamma_part_aval.dtype), - TEDType.kByte, # dummy dbeta_part te_dtype - False, # RMSNorm doesn't support zero_centered_gamma - epsilon, - sm_margin, - ) + if is_ffi_enabled(): + name = "te_rmsnorm_backward_ffi" + sm_margin = get_backward_sm_margin() + zero_centered_gamma = False # RMSNorm doesn't support zero_centered_gamma + out = ffi.ffi_lowering(name)( + ctx, + dz, + x, + rsigma, + gamma, + zero_centered_gamma=zero_centered_gamma, + eps=epsilon, + sm_margin=sm_margin, + ) + else: + _, x_aval, _, gamma_aval = ctx.avals_in + x_type = ir.RankedTensorType(x.type) + x_shape = x_type.shape + g_type = ir.RankedTensorType(gamma.type) + g_shape = g_type.shape + dz_shape = ir.RankedTensorType(dz.type).shape + rsigma_shape = ir.RankedTensorType(rsigma.type).shape + + hidden_size = reduce(operator.mul, g_shape) + batch_size = reduce(operator.mul, x_shape) // hidden_size + + wkspace_aval, barrier_aval, dgamma_part_aval = ctx.avals_out[-3:] + + out_types = [ + ir.RankedTensorType.get(x_shape, x_type.element_type), + ir.RankedTensorType.get(g_shape, g_type.element_type), + ir.RankedTensorType.get( + wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype) + ), + ir.RankedTensorType.get( + barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype) + ), + ir.RankedTensorType.get( + dgamma_part_aval.shape, jax_dtype_to_ir_dtype(dgamma_part_aval.dtype) + ), + ] + operands = [dz, rsigma, x, gamma] + operand_shapes = [dz_shape, rsigma_shape, x_shape, g_shape] + args = CustomCallArgsWrapper(out_types, operands, operand_shapes) + + sm_margin = get_backward_sm_margin() + + opaque = transformer_engine_jax.pack_norm_descriptor( + batch_size, + hidden_size, + wkspace_aval.size, + barrier_aval.size, + dgamma_part_aval.shape, + (0,), # no dbeta_part for RMSnorm + jax_dtype_to_te_dtype(x_aval.dtype), + jax_dtype_to_te_dtype(gamma_aval.dtype), + jax_dtype_to_te_dtype(wkspace_aval.dtype), + jax_dtype_to_te_dtype(barrier_aval.dtype), + jax_dtype_to_te_dtype(dgamma_part_aval.dtype), + TEDType.kByte, # dummy dbeta_part te_dtype + False, # RMSNorm doesn't support zero_centered_gamma + epsilon, + sm_margin, + ) - out = custom_caller(RmsNormBwdPrimitive.name, args, opaque, False) + out = custom_caller(RmsNormBwdPrimitive.name, args, opaque, False) return out @@ -1058,64 +1128,84 @@ def lowering( assert g_type == b_type assert g_shape == b_shape - ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype) - ir_mu_dtype = ir.F32Type.get() - ir_rsigma_dtype = ir.F32Type.get() - ir_amax_type = ir.RankedTensorType(amax.type) - ir_amax_dtype = ir_amax_type.element_type - ir_amax_shape = ir_amax_type.shape - ir_scale_shape = ir_amax_shape - ir_scale_inv_shape = ir_amax_shape - - out_shape = x_shape - hidden_size = reduce(operator.mul, g_shape) - batch_shape = out_shape[:-1] - batch_size = 
reduce(operator.mul, x_shape) // hidden_size - - wkspace_aval, barrier_aval = ctx.avals_out[-2:] - - out_types = [ - ir.RankedTensorType.get(out_shape, ir_out_dtype), - ir.RankedTensorType.get(batch_shape, ir_mu_dtype), - ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype), - ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype), - ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)), - ir.RankedTensorType.get(barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)), - ] - operands = [x, gamma, beta, amax, scale, scale_inv] - operand_shapes = [ - x_shape, - g_shape, - b_shape, - ir_amax_shape, - ir_scale_shape, - ir_scale_inv_shape, - ] - args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - - sm_margin = get_forward_sm_margin() - - opaque = transformer_engine_jax.pack_norm_descriptor( - batch_size, - hidden_size, - wkspace_aval.size, - barrier_aval.size, - (0,), # no dgamma_part in FWD pass - (0,), # no dbeta_part in BWD pass - jax_dtype_to_te_dtype(x_aval.dtype), - jax_dtype_to_te_dtype(gamma_aval.dtype), - jax_dtype_to_te_dtype(wkspace_aval.dtype), - jax_dtype_to_te_dtype(barrier_aval.dtype), - TEDType.kByte, # dummy dgamma_part te_dtype - TEDType.kByte, # dummy dbeta_part te_dtype - zero_centered_gamma, - epsilon, - sm_margin, - ) + if is_ffi_enabled(): + name = "te_layernorm_forward_fp8_ffi" + sm_margin = get_forward_sm_margin() + out = ffi.ffi_lowering(name, operand_output_aliases={3: 3})( + ctx, + x, + gamma, + beta, + amax, + scale, + scale_inv, + zero_centered_gamma=zero_centered_gamma, + eps=epsilon, + sm_margin=sm_margin, + ) + else: + ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype) + ir_mu_dtype = ir.F32Type.get() + ir_rsigma_dtype = ir.F32Type.get() + ir_amax_type = ir.RankedTensorType(amax.type) + ir_amax_dtype = ir_amax_type.element_type + ir_amax_shape = ir_amax_type.shape + ir_scale_shape = ir_amax_shape + ir_scale_inv_shape = ir_amax_shape + + out_shape = x_shape + hidden_size = reduce(operator.mul, g_shape) + batch_shape = out_shape[:-1] + batch_size = reduce(operator.mul, x_shape) // hidden_size + + wkspace_aval, barrier_aval = ctx.avals_out[-2:] + + out_types = [ + ir.RankedTensorType.get(out_shape, ir_out_dtype), + ir.RankedTensorType.get(batch_shape, ir_mu_dtype), + ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype), + ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype), + ir.RankedTensorType.get( + wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype) + ), + ir.RankedTensorType.get( + barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype) + ), + ] + operands = [x, gamma, beta, amax, scale, scale_inv] + operand_shapes = [ + x_shape, + g_shape, + b_shape, + ir_amax_shape, + ir_scale_shape, + ir_scale_inv_shape, + ] + args = CustomCallArgsWrapper(out_types, operands, operand_shapes) + + sm_margin = get_forward_sm_margin() + + opaque = transformer_engine_jax.pack_norm_descriptor( + batch_size, + hidden_size, + wkspace_aval.size, + barrier_aval.size, + (0,), # no dgamma_part in FWD pass + (0,), # no dbeta_part in BWD pass + jax_dtype_to_te_dtype(x_aval.dtype), + jax_dtype_to_te_dtype(gamma_aval.dtype), + jax_dtype_to_te_dtype(wkspace_aval.dtype), + jax_dtype_to_te_dtype(barrier_aval.dtype), + TEDType.kByte, # dummy dgamma_part te_dtype + TEDType.kByte, # dummy dbeta_part te_dtype + zero_centered_gamma, + epsilon, + sm_margin, + ) - out = custom_caller( - LayerNormFwdFp8Primitive.name, args, opaque, False, operand_output_aliases={3: 3} - ) + out = custom_caller( + LayerNormFwdFp8Primitive.name, 
args, opaque, False, operand_output_aliases={3: 3} + ) return out @@ -1345,67 +1435,87 @@ def lowering(ctx, x, gamma, amax, scale, scale_inv, *, out_dtype, epsilon): # Currently only support casting to E4M3 only in C side. assert out_dtype == jnp.float8_e4m3fn - x_aval, gamma_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in - - assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] - assert amax_aval.dtype == jnp.float32 - assert scale_aval.dtype == jnp.float32 - assert scale_inv_aval.dtype == jnp.float32 - - x_type = ir.RankedTensorType(x.type) - x_shape = x_type.shape - g_type = ir.RankedTensorType(gamma.type) - g_shape = g_type.shape - - ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype) - ir_rsigma_dtype = ir.F32Type.get() - ir_amax_type = ir.RankedTensorType(amax.type) - ir_amax_dtype = ir_amax_type.element_type - ir_amax_shape = ir_amax_type.shape - ir_scale_shape = ir_amax_shape - ir_scale_inv_shape = ir_amax_shape - - out_shape = x_shape - hidden_size = reduce(operator.mul, g_shape) - batch_shape = out_shape[:-1] - batch_size = reduce(operator.mul, x_shape) // hidden_size - - wkspace_aval, barrier_aval = ctx.avals_out[-2:] - - out_types = [ - ir.RankedTensorType.get(out_shape, ir_out_dtype), - ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype), - ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype), - ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)), - ir.RankedTensorType.get(barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)), - ] - operands = [x, gamma, amax, scale, scale_inv] - operand_shapes = [x_shape, g_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape] - args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - - sm_margin = get_forward_sm_margin() - - opaque = transformer_engine_jax.pack_norm_descriptor( - batch_size, - hidden_size, - wkspace_aval.size, - barrier_aval.size, - (0,), # no dgamma_part in FWD pass - (0,), # no dbeta_part in BWD pass - jax_dtype_to_te_dtype(x_aval.dtype), - jax_dtype_to_te_dtype(gamma_aval.dtype), - jax_dtype_to_te_dtype(wkspace_aval.dtype), - jax_dtype_to_te_dtype(barrier_aval.dtype), - TEDType.kByte, # dummy dgamma_part te_dtype - TEDType.kByte, # dummy dbeta_part te_dtype - False, # RMSNorm doesn't support zero_centered_gamma - epsilon, - sm_margin, - ) + if is_ffi_enabled(): + name = "te_rmsnorm_forward_fp8_ffi" + sm_margin = get_forward_sm_margin() + zero_centered_gamma = False # RMSNorm doesn't support zero_centered_gamma + out = ffi.ffi_lowering(name, operand_output_aliases={2: 2})( + ctx, + x, + gamma, + amax, + scale, + scale_inv, + zero_centered_gamma=zero_centered_gamma, + eps=epsilon, + sm_margin=sm_margin, + ) + else: + x_aval, gamma_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in + + assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] + assert amax_aval.dtype == jnp.float32 + assert scale_aval.dtype == jnp.float32 + assert scale_inv_aval.dtype == jnp.float32 + + x_type = ir.RankedTensorType(x.type) + x_shape = x_type.shape + g_type = ir.RankedTensorType(gamma.type) + g_shape = g_type.shape + + ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype) + ir_rsigma_dtype = ir.F32Type.get() + ir_amax_type = ir.RankedTensorType(amax.type) + ir_amax_dtype = ir_amax_type.element_type + ir_amax_shape = ir_amax_type.shape + ir_scale_shape = ir_amax_shape + ir_scale_inv_shape = ir_amax_shape + + out_shape = x_shape + hidden_size = reduce(operator.mul, g_shape) + batch_shape = out_shape[:-1] + batch_size = reduce(operator.mul, x_shape) // 
hidden_size + + wkspace_aval, barrier_aval = ctx.avals_out[-2:] + + out_types = [ + ir.RankedTensorType.get(out_shape, ir_out_dtype), + ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype), + ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype), + ir.RankedTensorType.get( + wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype) + ), + ir.RankedTensorType.get( + barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype) + ), + ] + operands = [x, gamma, amax, scale, scale_inv] + operand_shapes = [x_shape, g_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape] + args = CustomCallArgsWrapper(out_types, operands, operand_shapes) + + sm_margin = get_forward_sm_margin() + + opaque = transformer_engine_jax.pack_norm_descriptor( + batch_size, + hidden_size, + wkspace_aval.size, + barrier_aval.size, + (0,), # no dgamma_part in FWD pass + (0,), # no dbeta_part in BWD pass + jax_dtype_to_te_dtype(x_aval.dtype), + jax_dtype_to_te_dtype(gamma_aval.dtype), + jax_dtype_to_te_dtype(wkspace_aval.dtype), + jax_dtype_to_te_dtype(barrier_aval.dtype), + TEDType.kByte, # dummy dgamma_part te_dtype + TEDType.kByte, # dummy dbeta_part te_dtype + False, # RMSNorm doesn't support zero_centered_gamma + epsilon, + sm_margin, + ) - out = custom_caller( - RmsNormFwdFp8Primitive.name, args, opaque, False, operand_output_aliases={2: 2} - ) + out = custom_caller( + RmsNormFwdFp8Primitive.name, args, opaque, False, operand_output_aliases={2: 2} + ) return out diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 48bf4d969a..062bbbf0fb 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -9,6 +9,7 @@ from jax import dtypes from jax.interpreters.mlir import ir from jax.sharding import PartitionSpec, NamedSharding +from jax.extend import ffi from transformer_engine import transformer_engine_jax from transformer_engine.transformer_engine_jax import DType as TEDType @@ -20,6 +21,7 @@ check_valid_batch_dims, jax_dtype_to_te_dtype, jax_dtype_to_ir_dtype, + is_ffi_enabled, ) from ..sharding import all_reduce_max_along_all_axes_except_PP @@ -84,30 +86,36 @@ def lowering(ctx, x, amax, scale, scale_inv, *, out_dtype): assert amax_aval.dtype == jnp.float32 assert scale_aval.dtype == jnp.float32 assert scale_inv_aval.dtype == jnp.float32 - ir_x_type = ir.RankedTensorType(x.type) - ir_x_shape = ir_x_type.shape - ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype) - ir_amax_type = ir.RankedTensorType(amax.type) - ir_amax_dtype = ir_amax_type.element_type - ir_amax_shape = ir_amax_type.shape - ir_scale_shape = ir_amax_shape - ir_scale_inv_shape = ir_amax_shape - - out_types = [ - ir.RankedTensorType.get(ir_x_shape, ir_out_dtype), - ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype), - ] - operands = [x, amax, scale, scale_inv] - operand_shapes = [ir_x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape] - args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - - opaque = transformer_engine_jax.pack_common_descriptor( - ir_x_shape, jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(out_dtype) - ) + if is_ffi_enabled(): + name = "te_quantize_ffi" + out = ffi.ffi_lowering(name, operand_output_aliases={1: 1})( + ctx, x, amax, scale, scale_inv + ) + else: + ir_x_type = ir.RankedTensorType(x.type) + ir_x_shape = ir_x_type.shape + ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype) + ir_amax_type = ir.RankedTensorType(amax.type) + ir_amax_dtype = 
ir_amax_type.element_type + ir_amax_shape = ir_amax_type.shape + ir_scale_shape = ir_amax_shape + ir_scale_inv_shape = ir_amax_shape + + out_types = [ + ir.RankedTensorType.get(ir_x_shape, ir_out_dtype), + ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype), + ] + operands = [x, amax, scale, scale_inv] + operand_shapes = [ir_x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape] + args = CustomCallArgsWrapper(out_types, operands, operand_shapes) + + opaque = transformer_engine_jax.pack_common_descriptor( + ir_x_shape, jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(out_dtype) + ) - out = custom_caller( - CastFP8Primitive.name, args, opaque, False, operand_output_aliases={1: 1} - ) + out = custom_caller( + CastFP8Primitive.name, args, opaque, False, operand_output_aliases={1: 1} + ) return out diff --git a/transformer_engine/jax/cpp_extensions/softmax.py b/transformer_engine/jax/cpp_extensions/softmax.py index bf92c00de3..a12943f4c2 100644 --- a/transformer_engine/jax/cpp_extensions/softmax.py +++ b/transformer_engine/jax/cpp_extensions/softmax.py @@ -12,12 +12,13 @@ from jax import core, dtypes from jax.interpreters.mlir import ir from jax.sharding import PartitionSpec, NamedSharding +from jax.extend import ffi from transformer_engine import transformer_engine_jax from .base import BasePrimitive, register_primitive from .custom_call import custom_caller, CustomCallArgsWrapper -from .misc import get_padded_spec, check_valid_batch_dims, jax_dtype_to_te_dtype +from .misc import get_padded_spec, check_valid_batch_dims, jax_dtype_to_te_dtype, is_ffi_enabled from ..softmax import SoftmaxType @@ -133,32 +134,36 @@ def forward_lowering(name, ctx, logits, *, scale_factor): """ softmax_forward lowering rules """ - (i_aval,) = ctx.avals_in - i_type = ir.RankedTensorType(logits.type) - i_shape = i_type.shape - # Assume [...Batch, Head, Q_Seqlen, K_Seqlen] - batch = reduce(operator.mul, i_shape[:-3]) - pad_batch = batch - heads = i_shape[-3] - q_seqlen = i_shape[-2] - k_seqlen = i_shape[-1] - - out_types = [ir.RankedTensorType.get(i_shape, i_type.element_type)] - operands = [logits] - operand_shapes = [i_shape] - args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - - opaque = transformer_engine_jax.pack_softmax_descriptor( - batch, - pad_batch, - heads, - q_seqlen, - k_seqlen, - jax_dtype_to_te_dtype(i_aval.dtype), - scale_factor, - ) + if is_ffi_enabled(): + ffi_name = name + "_ffi" + out = ffi.ffi_lowering(ffi_name)(ctx, logits, scale_factor=scale_factor) + else: + (i_aval,) = ctx.avals_in + i_type = ir.RankedTensorType(logits.type) + i_shape = i_type.shape + # Assume [...Batch, Head, Q_Seqlen, K_Seqlen] + batch = reduce(operator.mul, i_shape[:-3]) + pad_batch = batch + heads = i_shape[-3] + q_seqlen = i_shape[-2] + k_seqlen = i_shape[-1] + + out_types = [ir.RankedTensorType.get(i_shape, i_type.element_type)] + operands = [logits] + operand_shapes = [i_shape] + args = CustomCallArgsWrapper(out_types, operands, operand_shapes) + + opaque = transformer_engine_jax.pack_softmax_descriptor( + batch, + pad_batch, + heads, + q_seqlen, + k_seqlen, + jax_dtype_to_te_dtype(i_aval.dtype), + scale_factor, + ) - out = custom_caller(name, args, opaque, False) + out = custom_caller(name, args, opaque, False) return out @@ -240,37 +245,41 @@ def backward_lowering(name, ctx, dz, softmax_out, *, scale_factor): """ softmax_backward lowering rules """ - dz_aval, _ = ctx.avals_in - - dz_type = ir.RankedTensorType(dz.type) - dz_shape = dz_type.shape - - # Assume [...Batch, Head, Q_Seqlen, 
K_Seqlen] - batch = reduce(operator.mul, dz_shape[:-3]) - pad_batch = batch # unused - heads = dz_shape[-3] - q_seqlen = dz_shape[-2] - k_seqlen = dz_shape[-1] - - softmax_out_type = ir.RankedTensorType(softmax_out.type) - softmax_out_shape = softmax_out_type.shape - - out_types = [ir.RankedTensorType.get(dz_shape, dz_type.element_type)] - operands = [dz, softmax_out] - operand_shapes = [dz_shape, softmax_out_shape] - args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - - opaque = transformer_engine_jax.pack_softmax_descriptor( - batch, - pad_batch, - heads, - q_seqlen, - k_seqlen, - jax_dtype_to_te_dtype(dz_aval.dtype), - scale_factor, - ) + if is_ffi_enabled(): + ffi_name = name + "_ffi" + out = ffi.ffi_lowering(ffi_name)(ctx, dz, softmax_out, scale_factor=scale_factor) + else: + dz_aval, _ = ctx.avals_in + + dz_type = ir.RankedTensorType(dz.type) + dz_shape = dz_type.shape + + # Assume [...Batch, Head, Q_Seqlen, K_Seqlen] + batch = reduce(operator.mul, dz_shape[:-3]) + pad_batch = batch # unused + heads = dz_shape[-3] + q_seqlen = dz_shape[-2] + k_seqlen = dz_shape[-1] + + softmax_out_type = ir.RankedTensorType(softmax_out.type) + softmax_out_shape = softmax_out_type.shape + + out_types = [ir.RankedTensorType.get(dz_shape, dz_type.element_type)] + operands = [dz, softmax_out] + operand_shapes = [dz_shape, softmax_out_shape] + args = CustomCallArgsWrapper(out_types, operands, operand_shapes) + + opaque = transformer_engine_jax.pack_softmax_descriptor( + batch, + pad_batch, + heads, + q_seqlen, + k_seqlen, + jax_dtype_to_te_dtype(dz_aval.dtype), + scale_factor, + ) - out = custom_caller(name, args, opaque, False) + out = custom_caller(name, args, opaque, False) return out @@ -577,36 +586,39 @@ def lowering(ctx, logits, mask, *, scale_factor): """ te_scaled_masked_softmax_forward lowering rules """ + if is_ffi_enabled(): + ffi_name = "te_scaled_masked_softmax_forward_ffi" + out = ffi.ffi_lowering(ffi_name)(ctx, logits, mask, scale_factor=scale_factor) + else: + logits_aval, _ = ctx.avals_in + i_type = ir.RankedTensorType(logits.type) + i_shape = i_type.shape + # Assume [...Batch, Head, Q_Seqlen, K_Seqlen] + batch = reduce(operator.mul, i_shape[:-3]) + heads = i_shape[-3] + q_seqlen = i_shape[-2] + k_seqlen = i_shape[-1] + + mask_type = ir.RankedTensorType(mask.type) + mask_shape = mask_type.shape + pad_batch = reduce(operator.mul, mask_shape[:-3]) + + out_types = [ir.RankedTensorType.get(i_shape, i_type.element_type)] + operands = [logits, mask] + operand_shapes = [i_shape, mask_shape] + args = CustomCallArgsWrapper(out_types, operands, operand_shapes) + + opaque = transformer_engine_jax.pack_softmax_descriptor( + batch, + pad_batch, + heads, + q_seqlen, + k_seqlen, + jax_dtype_to_te_dtype(logits_aval.dtype), + scale_factor, + ) - logits_aval, _ = ctx.avals_in - i_type = ir.RankedTensorType(logits.type) - i_shape = i_type.shape - # Assume [...Batch, Head, Q_Seqlen, K_Seqlen] - batch = reduce(operator.mul, i_shape[:-3]) - heads = i_shape[-3] - q_seqlen = i_shape[-2] - k_seqlen = i_shape[-1] - - mask_type = ir.RankedTensorType(mask.type) - mask_shape = mask_type.shape - pad_batch = reduce(operator.mul, mask_shape[:-3]) - - out_types = [ir.RankedTensorType.get(i_shape, i_type.element_type)] - operands = [logits, mask] - operand_shapes = [i_shape, mask_shape] - args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - - opaque = transformer_engine_jax.pack_softmax_descriptor( - batch, - pad_batch, - heads, - q_seqlen, - k_seqlen, - 
jax_dtype_to_te_dtype(logits_aval.dtype), - scale_factor, - ) - - out = custom_caller(ScaledMaskedSoftmaxFwdPrimitive.name, args, opaque, False) + out = custom_caller(ScaledMaskedSoftmaxFwdPrimitive.name, args, opaque, False) return out diff --git a/transformer_engine/jax/cpp_extensions/transpose.py b/transformer_engine/jax/cpp_extensions/transpose.py index 963d7f09e8..2338572e30 100644 --- a/transformer_engine/jax/cpp_extensions/transpose.py +++ b/transformer_engine/jax/cpp_extensions/transpose.py @@ -64,6 +64,35 @@ def _jax_cast_transpose( return casted_output, casted_transposed_output, updated_amax +def _jax_dbias_cast_transpose( + dz, amax, scale, out_dtype, static_axis_boundary, transpose_axis_boundary +): + """ + JAX native dbias_cast_transpose implementation + """ + casted_dz, cast_transposed_dz, updated_amax = _jax_cast_transpose( + dz, + scale, + amax, + out_dtype=out_dtype, + static_axis_boundary=static_axis_boundary, + transpose_axis_boundary=transpose_axis_boundary, + ) + dbias = jnp.sum( + dz, + axis=tuple( + range( + transpose_axis_boundary + if transpose_axis_boundary > 0 + else transpose_axis_boundary + dz.ndim + ) + ), + keepdims=False, + ) + dbias = dbias.ravel() # C++ function returns an 1D array for dbias + return casted_dz, cast_transposed_dz, dbias, updated_amax + + class TransposePrimitive(BasePrimitive): """ Transpose Primitive @@ -102,32 +131,36 @@ def lowering(ctx, x, *, static_axis_boundary, transpose_axis_boundary): jnp.float8_e5m2, ] - ir_x_type = ir.RankedTensorType(x.type) - ir_x_shape = ir_x_type.shape - ir_out_dtype = jax_dtype_to_ir_dtype(x_aval.dtype) - if static_axis_boundary >= 0: - for i in range(static_axis_boundary + 1): - assert ir_x_shape[i] == 1 + if is_ffi_enabled(): + name = "te_transpose_ffi" + out = ffi.ffi_lowering(name)(ctx, x, transpose_axis=transpose_axis_boundary) + else: + ir_x_type = ir.RankedTensorType(x.type) + ir_x_shape = ir_x_type.shape + ir_out_dtype = jax_dtype_to_ir_dtype(x_aval.dtype) + if static_axis_boundary >= 0: + for i in range(static_axis_boundary + 1): + assert ir_x_shape[i] == 1 - transposed_x_shape = multidim_transpose( - ir_x_shape, static_axis_boundary, transpose_axis_boundary - ) + transposed_x_shape = multidim_transpose( + ir_x_shape, static_axis_boundary, transpose_axis_boundary + ) - out_types = [ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype)] - operands = [x] - operand_shapes = [ir_x_shape] - args = CustomCallArgsWrapper(out_types, operands, operand_shapes) + out_types = [ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype)] + operands = [x] + operand_shapes = [ir_x_shape] + args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - te_dtype = jax_dtype_to_te_dtype(x_aval.dtype) - contracted_x_shape = ( - reduce(operator.mul, ir_x_shape[:transpose_axis_boundary]), - reduce(operator.mul, ir_x_shape[transpose_axis_boundary:]), - ) - opaque = transformer_engine_jax.pack_common_descriptor( - contracted_x_shape, te_dtype, te_dtype - ) + te_dtype = jax_dtype_to_te_dtype(x_aval.dtype) + contracted_x_shape = ( + reduce(operator.mul, ir_x_shape[:transpose_axis_boundary]), + reduce(operator.mul, ir_x_shape[transpose_axis_boundary:]), + ) + opaque = transformer_engine_jax.pack_common_descriptor( + contracted_x_shape, te_dtype, te_dtype + ) - out = custom_caller(TransposePrimitive.name, args, opaque, False) + out = custom_caller(TransposePrimitive.name, args, opaque, False) return out @@ -415,12 +448,7 @@ def cast_transpose( """ if not CastTransposePrimitive.enabled(): return _jax_cast_transpose( - 
x, - scale, - amax, - out_dtype=out_dtype, - static_axis_boundary=static_axis_boundary, - transpose_axis_boundary=transpose_axis_boundary, + x, scale, amax, out_dtype, static_axis_boundary, transpose_axis_boundary ) return CastTransposePrimitive.outer_primitive.bind( x, @@ -508,45 +536,53 @@ def lowering( assert amax_aval.dtype == jnp.float32 assert scale_aval.dtype == jnp.float32 assert scale_inv_aval.dtype == jnp.float32 - ir_dz_type = ir.RankedTensorType(dz.type) - ir_dz_shape = ir_dz_type.shape - batch_size = reduce(operator.mul, ir_dz_shape[:transpose_axis_boundary]) - ir_hidden_size = reduce(operator.mul, ir_dz_shape[transpose_axis_boundary:]) - contracted_dz_shape = (batch_size, ir_hidden_size) - ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype) - ir_amax_type = ir.RankedTensorType(amax.type) - ir_amax_dtype = ir_amax_type.element_type - ir_amax_shape = ir_amax_type.shape - ir_scale_shape = ir_amax_shape - ir_scale_inv_shape = ir_amax_shape - transposed_dz_shape = multidim_transpose( - ir_dz_shape, static_axis_boundary, transpose_axis_boundary - ) - dbias_shape = (*ir_dz_shape[: static_axis_boundary + 1], ir_hidden_size) + if is_ffi_enabled(): + name = "te_dbias_cast_transpose_ffi" + out = ffi.ffi_lowering(name, operand_output_aliases={1: 3})( + ctx, dz, amax, scale, scale_inv, transpose_axis=transpose_axis_boundary + ) + else: + ir_dz_type = ir.RankedTensorType(dz.type) + ir_dz_shape = ir_dz_type.shape + batch_size = reduce(operator.mul, ir_dz_shape[:transpose_axis_boundary]) + ir_hidden_size = reduce(operator.mul, ir_dz_shape[transpose_axis_boundary:]) + contracted_dz_shape = (batch_size, ir_hidden_size) + ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype) + ir_amax_type = ir.RankedTensorType(amax.type) + ir_amax_dtype = ir_amax_type.element_type + ir_amax_shape = ir_amax_type.shape + ir_scale_shape = ir_amax_shape + ir_scale_inv_shape = ir_amax_shape + transposed_dz_shape = multidim_transpose( + ir_dz_shape, static_axis_boundary, transpose_axis_boundary + ) + dbias_shape = (*ir_dz_shape[: static_axis_boundary + 1], ir_hidden_size) - wkspace_aval = ctx.avals_out[-1] + wkspace_aval = ctx.avals_out[-1] - out_types = [ - ir.RankedTensorType.get(ir_dz_shape, ir_out_dtype), - ir.RankedTensorType.get(transposed_dz_shape, ir_out_dtype), - ir.RankedTensorType.get(dbias_shape, ir_dz_type.element_type), - ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype), - ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)), - ] - operands = [dz, amax, scale, scale_inv] - operand_shapes = [ir_dz_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape] - args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - opaque = transformer_engine_jax.pack_common_wk_descriptor( - contracted_dz_shape, - wkspace_aval.shape, - jax_dtype_to_te_dtype(dz_aval.dtype), - jax_dtype_to_te_dtype(out_dtype), - jax_dtype_to_te_dtype(wkspace_aval.dtype), - ) + out_types = [ + ir.RankedTensorType.get(ir_dz_shape, ir_out_dtype), + ir.RankedTensorType.get(transposed_dz_shape, ir_out_dtype), + ir.RankedTensorType.get(dbias_shape, ir_dz_type.element_type), + ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype), + ir.RankedTensorType.get( + wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype) + ), + ] + operands = [dz, amax, scale, scale_inv] + operand_shapes = [ir_dz_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape] + args = CustomCallArgsWrapper(out_types, operands, operand_shapes) + opaque = transformer_engine_jax.pack_common_wk_descriptor( + contracted_dz_shape, + 
wkspace_aval.shape, + jax_dtype_to_te_dtype(dz_aval.dtype), + jax_dtype_to_te_dtype(out_dtype), + jax_dtype_to_te_dtype(wkspace_aval.dtype), + ) - out = custom_caller( - DBiasCastTransposePrimitive.name, args, opaque, False, operand_output_aliases={1: 3} - ) + out = custom_caller( + DBiasCastTransposePrimitive.name, args, opaque, False, operand_output_aliases={1: 3} + ) return out @@ -673,26 +709,9 @@ def dbias_cast_transpose( static_axis_boundary = -1 # means no static axes if not DBiasCastTransposePrimitive.enabled(): - casted_dz, cast_transposed_dz, updated_amax = _jax_cast_transpose( - dz, - scale, - amax, - out_dtype=out_dtype, - static_axis_boundary=static_axis_boundary, - transpose_axis_boundary=transpose_axis_boundary, - ) - dbias = jnp.sum( - dz, - axis=tuple( - range( - transpose_axis_boundary - if transpose_axis_boundary > 0 - else transpose_axis_boundary + dz.ndim - ) - ), - keepdims=False, + return _jax_dbias_cast_transpose( + dz, amax, scale, out_dtype, static_axis_boundary, transpose_axis_boundary ) - return casted_dz, cast_transposed_dz, dbias, updated_amax return DBiasCastTransposePrimitive.outer_primitive.bind( dz, @@ -712,8 +731,8 @@ class DActLuDBiasCastTransposePrimitive(BasePrimitive): name = "te_dact_lu_dbias_cast_transpose" multiple_results = True - # out_dtype, static_axis_boundary, transpose_axis_boundary, act_enum - impl_static_args = (5, 6, 7, 8) + # out_dtype, static_axis_boundary, act_enum + impl_static_args = (5, 6, 7) inner_primitive = None outer_primitive = None @@ -727,7 +746,6 @@ def abstract( *, out_dtype, static_axis_boundary, - transpose_axis_boundary, act_enum ): # pylint: disable=unused-argument """ @@ -742,7 +760,7 @@ def abstract( ir_hidden_szie = dz_aval.shape[-1] gi_hidden_size = x_aval.shape[-1] assert ir_hidden_szie == gi_hidden_size - t_shape = multidim_transpose(x_aval.shape, static_axis_boundary, transpose_axis_boundary) + t_shape = multidim_transpose(x_aval.shape, static_axis_boundary, -2) out = dz_aval.update(shape=x_aval.shape, dtype=out_dtype) t_out = dz_aval.update(shape=t_shape, dtype=out_dtype) @@ -775,19 +793,7 @@ def outer_abstract(*args, **kwargs): return out, t_out, dbias, updated_amax_aval @staticmethod - def lowering( - ctx, - dz, - x, - amax, - scale, - scale_inv, - *, - out_dtype, - static_axis_boundary, - transpose_axis_boundary, - act_enum - ): + def lowering(ctx, dz, x, amax, scale, scale_inv, *, out_dtype, static_axis_boundary, act_enum): """ te_dgated_act_lu_cast_transpose_p lowering rules """ @@ -797,55 +803,67 @@ def lowering( assert amax_aval.dtype == jnp.float32 assert scale_aval.dtype == jnp.float32 assert scale_inv_aval.dtype == jnp.float32 - ir_dz_type = ir.RankedTensorType(dz.type) - ir_dz_shape = ir_dz_type.shape - x_type = ir.RankedTensorType(x.type) - x_shape = x_type.shape - dz_batch_szie = reduce(operator.mul, ir_dz_shape[:-1]) - x_batch_size = reduce(operator.mul, x_shape[:-2]) - assert dz_batch_szie == x_batch_size - ir_hidden_szie = ir_dz_shape[-1] - contracted_x_shape = (x_batch_size, ir_hidden_szie) - - ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype) - ir_amax_type = ir.RankedTensorType(amax.type) - ir_amax_dtype = ir_amax_type.element_type - ir_amax_shape = ir_amax_type.shape - ir_scale_shape = ir_amax_shape - ir_scale_inv_shape = ir_amax_shape - transposed_x_shape = multidim_transpose( - x_shape, static_axis_boundary, transpose_axis_boundary - ) - dbias_shape = (*x_shape[: static_axis_boundary + 1], ir_hidden_szie) + if is_ffi_enabled(): + name = "te_dact_lu_dbias_cast_transpose_ffi" + out = 
ffi.ffi_lowering(name, operand_output_aliases={2: 3})( + ctx, dz, x, amax, scale, scale_inv, act_enum=int(act_enum) + ) + else: + ir_dz_type = ir.RankedTensorType(dz.type) + ir_dz_shape = ir_dz_type.shape + x_type = ir.RankedTensorType(x.type) + x_shape = x_type.shape + dz_batch_szie = reduce(operator.mul, ir_dz_shape[:-1]) + x_batch_size = reduce(operator.mul, x_shape[:-2]) + assert dz_batch_szie == x_batch_size + ir_hidden_szie = ir_dz_shape[-1] + contracted_x_shape = (x_batch_size, ir_hidden_szie) - wkspace_aval = ctx.avals_out[-1] + ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype) + ir_amax_type = ir.RankedTensorType(amax.type) + ir_amax_dtype = ir_amax_type.element_type + ir_amax_shape = ir_amax_type.shape + ir_scale_shape = ir_amax_shape + ir_scale_inv_shape = ir_amax_shape + transposed_x_shape = multidim_transpose(x_shape, static_axis_boundary, -2) + dbias_shape = (*x_shape[: static_axis_boundary + 1], ir_hidden_szie) - out_types = [ - ir.RankedTensorType.get(x_shape, ir_out_dtype), - ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype), - ir.RankedTensorType.get(dbias_shape, ir_dz_type.element_type), - ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype), - ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)), - ] - operands = [dz, x, amax, scale, scale_inv] - operand_shapes = [ir_dz_shape, x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape] - args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - opaque = transformer_engine_jax.pack_common_wk_descriptor( - contracted_x_shape, - wkspace_aval.shape, - jax_dtype_to_te_dtype(dz_aval.dtype), - jax_dtype_to_te_dtype(out_dtype), - jax_dtype_to_te_dtype(wkspace_aval.dtype), - act_enum, - ) + wkspace_aval = ctx.avals_out[-1] - out = custom_caller( - DActLuDBiasCastTransposePrimitive.name, - args, - opaque, - False, - operand_output_aliases={2: 3}, - ) + out_types = [ + ir.RankedTensorType.get(x_shape, ir_out_dtype), + ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype), + ir.RankedTensorType.get(dbias_shape, ir_dz_type.element_type), + ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype), + ir.RankedTensorType.get( + wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype) + ), + ] + operands = [dz, x, amax, scale, scale_inv] + operand_shapes = [ + ir_dz_shape, + x_shape, + ir_amax_shape, + ir_scale_shape, + ir_scale_inv_shape, + ] + args = CustomCallArgsWrapper(out_types, operands, operand_shapes) + opaque = transformer_engine_jax.pack_common_wk_descriptor( + contracted_x_shape, + wkspace_aval.shape, + jax_dtype_to_te_dtype(dz_aval.dtype), + jax_dtype_to_te_dtype(out_dtype), + jax_dtype_to_te_dtype(wkspace_aval.dtype), + act_enum, + ) + + out = custom_caller( + DActLuDBiasCastTransposePrimitive.name, + args, + opaque, + False, + operand_output_aliases={2: 3}, + ) return out @@ -858,7 +876,6 @@ def impl( scale_inv, out_dtype, static_axis_boundary, - transpose_axis_boundary, act_enum, ): """ @@ -873,21 +890,12 @@ def impl( scale_inv, out_dtype=out_dtype, static_axis_boundary=static_axis_boundary, - transpose_axis_boundary=transpose_axis_boundary, act_enum=act_enum, ) return out, t_out, dbias, updated_amax @staticmethod - def batcher( - batched_args, - batch_dims, - *, - out_dtype, - static_axis_boundary, - transpose_axis_boundary, - act_enum - ): + def batcher(batched_args, batch_dims, *, out_dtype, static_axis_boundary, act_enum): """ to describe batch rules for vmap """ @@ -897,10 +905,6 @@ def batcher( dz, x, amax, scale, scale_inv = batched_args x_bdim, _, 
amax_bdim, _, _ = batch_dims - # Minus batch dim. - transpose_axis_boundary = normalize_axis_boundary(transpose_axis_boundary, x.ndim - 1) - transpose_axis_boundary += 1 # Plus batch dim - out_bdims = x_bdim, x_bdim, x_bdim, amax_bdim return ( DActLuDBiasCastTransposePrimitive.outer_primitive.bind( @@ -911,7 +915,6 @@ def batcher( scale_inv, out_dtype=out_dtype, static_axis_boundary=x_bdim, - transpose_axis_boundary=transpose_axis_boundary, act_enum=act_enum, ), out_bdims, @@ -921,7 +924,6 @@ def batcher( def infer_sharding_from_operands( out_dtype, static_axis_boundary, - transpose_axis_boundary, act_enum, mesh, arg_infos, @@ -930,7 +932,7 @@ def infer_sharding_from_operands( del out_dtype, result_infos, act_enum x_spec = get_padded_spec(arg_infos[1]) out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) - xt_spec = multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary) + xt_spec = multidim_transpose(x_spec, static_axis_boundary, -2) tranposed_out_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec)) dbias_shaprding = NamedSharding( mesh, PartitionSpec(*x_spec[: static_axis_boundary + 1], x_spec[-1]) @@ -942,7 +944,6 @@ def infer_sharding_from_operands( def partition( out_dtype, static_axis_boundary, - transpose_axis_boundary, act_enum, mesh, arg_infos, @@ -951,7 +952,7 @@ def partition( del result_infos x_spec = get_padded_spec(arg_infos[1]) casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) - xt_spec = multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary) + xt_spec = multidim_transpose(x_spec, static_axis_boundary, -2) casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec)) dbias_shaprding = NamedSharding( @@ -977,7 +978,6 @@ def sharded_impl(dz, x, amax, scale, scale_inv): scale_inv, out_dtype=out_dtype, static_axis_boundary=static_axis_boundary, - transpose_axis_boundary=transpose_axis_boundary, act_enum=act_enum, ) ) @@ -999,7 +999,6 @@ def dact_lu_dbias_cast_transpose( scale_inv: jnp.ndarray, out_dtype: TEDType, static_axis_boundary: int, - transpose_axis_boundary: int = -1, activation_type: Sequence[Union[str, Callable]] = ("gelu",), ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: """ @@ -1013,27 +1012,10 @@ def dact_lu_dbias_cast_transpose( if not DActLuDBiasCastTransposePrimitive.enabled(): _, vjp_func = jax.vjp(partial(_jax_act_lu, activation_type=activation_type), x) (dx,) = vjp_func(dz) - casted_dx, cast_transposed_dx, updated_amax = _jax_cast_transpose( - dx, - scale, - amax, - out_dtype=out_dtype, - static_axis_boundary=static_axis_boundary, - transpose_axis_boundary=transpose_axis_boundary, - ) - dbias = jnp.squeeze( - jnp.sum( - dx, - axis=tuple( - range( - transpose_axis_boundary - if transpose_axis_boundary > 0 - else transpose_axis_boundary + dx.ndim - ) - ), - ) + transpose_axis_boundary = -2 + return _jax_dbias_cast_transpose( + dx, amax, scale, out_dtype, static_axis_boundary, transpose_axis_boundary ) - return casted_dx, cast_transposed_dx, dbias, updated_amax act_type_id = ActivationEnum[activation_type] return DActLuDBiasCastTransposePrimitive.outer_primitive.bind( @@ -1044,7 +1026,6 @@ def dact_lu_dbias_cast_transpose( scale_inv, out_dtype=out_dtype, static_axis_boundary=static_axis_boundary, - transpose_axis_boundary=transpose_axis_boundary, act_enum=act_type_id, ) @@ -1102,47 +1083,59 @@ def lowering(ctx, dz, x, amax, scale, scale_inv, *, out_dtype, static_axis_bound assert amax_aval.dtype == jnp.float32 assert scale_aval.dtype == jnp.float32 assert scale_inv_aval.dtype == 
jnp.float32 - ir_dz_type = ir.RankedTensorType(dz.type) - ir_dz_shape = ir_dz_type.shape - x_type = ir.RankedTensorType(x.type) - x_shape = x_type.shape - dz_batch_szie = reduce(operator.mul, ir_dz_shape[:-1]) - x_batch_size = reduce(operator.mul, x_shape[:-2]) - assert dz_batch_szie == x_batch_size - assert x_shape[-2] == 2 # Linear + GeLU - ir_hidden_szie = ir_dz_shape[-1] - gi_hidden_size = x_shape[-1] - assert ir_hidden_szie == gi_hidden_size - ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype) - ir_amax_type = ir.RankedTensorType(amax.type) - ir_amax_dtype = ir_amax_type.element_type - ir_amax_shape = ir_amax_type.shape - ir_scale_shape = ir_amax_shape - ir_scale_inv_shape = ir_amax_shape - transposed_x_shape = multidim_transpose(x_shape, static_axis_boundary, -2) - out_types = [ - ir.RankedTensorType.get(x_shape, ir_out_dtype), - ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype), - ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype), - ] - operands = [dz, x, amax, scale, scale_inv] - operand_shapes = [ir_dz_shape, x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape] - args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - contracted_x_shape = (x_batch_size, x_shape[-1]) - opaque = transformer_engine_jax.pack_common_descriptor( - contracted_x_shape, - jax_dtype_to_te_dtype(dz_aval.dtype), - jax_dtype_to_te_dtype(out_dtype), - act_enum, - ) + if is_ffi_enabled(): + name = "te_dgated_act_lu_cast_transpose_ffi" + out = ffi.ffi_lowering(name, operand_output_aliases={2: 2})( + ctx, dz, x, amax, scale, scale_inv, act_enum=int(act_enum) + ) + else: + ir_dz_type = ir.RankedTensorType(dz.type) + ir_dz_shape = ir_dz_type.shape + x_type = ir.RankedTensorType(x.type) + x_shape = x_type.shape + dz_batch_szie = reduce(operator.mul, ir_dz_shape[:-1]) + x_batch_size = reduce(operator.mul, x_shape[:-2]) + assert dz_batch_szie == x_batch_size + assert x_shape[-2] == 2 # Linear + GeLU + ir_hidden_szie = ir_dz_shape[-1] + gi_hidden_size = x_shape[-1] + assert ir_hidden_szie == gi_hidden_size + ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype) + ir_amax_type = ir.RankedTensorType(amax.type) + ir_amax_dtype = ir_amax_type.element_type + ir_amax_shape = ir_amax_type.shape + ir_scale_shape = ir_amax_shape + ir_scale_inv_shape = ir_amax_shape + transposed_x_shape = multidim_transpose(x_shape, static_axis_boundary, -2) + out_types = [ + ir.RankedTensorType.get(x_shape, ir_out_dtype), + ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype), + ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype), + ] + operands = [dz, x, amax, scale, scale_inv] + operand_shapes = [ + ir_dz_shape, + x_shape, + ir_amax_shape, + ir_scale_shape, + ir_scale_inv_shape, + ] + args = CustomCallArgsWrapper(out_types, operands, operand_shapes) + contracted_x_shape = (x_batch_size, x_shape[-1]) + opaque = transformer_engine_jax.pack_common_descriptor( + contracted_x_shape, + jax_dtype_to_te_dtype(dz_aval.dtype), + jax_dtype_to_te_dtype(out_dtype), + act_enum, + ) - out = custom_caller( - DgatedActLuCastTransposePrimitive.name, - args, - opaque, - False, - operand_output_aliases={2: 2}, - ) + out = custom_caller( + DgatedActLuCastTransposePrimitive.name, + args, + opaque, + False, + operand_output_aliases={2: 2}, + ) return out diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index c233177e28..02e6aaf9d5 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -151,26 +151,32 @@ pybind11::bytes 
PackCustomCallFusedAttnDescriptor( void Transpose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); +XLA_FFI_DECLARE_HANDLER_SYMBOL(TransposeHandler); + void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); +XLA_FFI_DECLARE_HANDLER_SYMBOL(CastTransposeHandler); + pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType out_dtype); -XLA_FFI_DECLARE_HANDLER_SYMBOL(CastTransposeHandler); - void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); +XLA_FFI_DECLARE_HANDLER_SYMBOL(DBiasCastTransposeHandler); + // Activation size_t get_activation_len(NVTE_Activation_Type activation_enum); void ActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); +XLA_FFI_DECLARE_HANDLER_SYMBOL(ActLuHandler); + void ActLuFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); -void DActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); +XLA_FFI_DECLARE_HANDLER_SYMBOL(ActLuFP8Handler); -XLA_FFI_DECLARE_HANDLER_SYMBOL(ActLuHandler); +void DActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuHandler); @@ -180,9 +186,13 @@ pybind11::tuple GetDActDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_ void DActLuDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); +XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuDBiasCastTransposeHandler); + void DGatedActLuCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); +XLA_FFI_DECLARE_HANDLER_SYMBOL(DGatedActLuCastTransposeHandler); + // Normalization pybind11::tuple GetLayerNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size, @@ -192,9 +202,13 @@ pybind11::tuple GetLayerNormForwardWorkspaceSizes(size_t batch_size, size_t hidd void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); +XLA_FFI_DECLARE_HANDLER_SYMBOL(LayerNormForwardHandler); + void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); +XLA_FFI_DECLARE_HANDLER_SYMBOL(LayerNormForwardFP8Handler); + pybind11::tuple GetLayerNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType w_dtype, bool is_layer_norm, bool zero_centered_gamma, @@ -202,18 +216,30 @@ pybind11::tuple GetLayerNormBackwardWorkspaceSizes(size_t batch_size, size_t hid void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); +XLA_FFI_DECLARE_HANDLER_SYMBOL(LayerNormBackwardHandler); + void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); +XLA_FFI_DECLARE_HANDLER_SYMBOL(RMSNormForwardHandler); + void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); +XLA_FFI_DECLARE_HANDLER_SYMBOL(RMSNormForwardFP8Handler); + void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); +XLA_FFI_DECLARE_HANDLER_SYMBOL(RMSNormBackwardHandler); + // Quantization void Quantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); +XLA_FFI_DECLARE_HANDLER_SYMBOL(QuantizeHandler); + void Dequantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); +XLA_FFI_DECLARE_HANDLER_SYMBOL(DequantizeHandler); + // Softmax void ScaledSoftmaxForward(cudaStream_t 
stream, void **buffers, const char *opaque, @@ -234,8 +260,23 @@ void ScaledUpperTriangMaskedSoftmaxForward(cudaStream_t stream, void **buffers, void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque, std::size_t opaque_len); +XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledSoftmaxForwardHandler); + +XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledSoftmaxBackwardHandler); + +XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledMaskedSoftmaxForwardHandler); + +XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledMaskedSoftmaxBackwardHandler); + +XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledUpperTriangMaskedSoftmaxForwardHandler); + +XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledUpperTriangMaskedSoftmaxBackwardHandler); + // Attention +// Cudnn helpers +XLA_FFI_DECLARE_HANDLER_SYMBOL(CudnnHandleInitHandler); + NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, float dropout_probability, @@ -253,6 +294,8 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); +XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnForwardHandler); + pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim, @@ -263,6 +306,8 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); +XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnBackwardHandler); + } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/activation.cpp b/transformer_engine/jax/csrc/extensions/activation.cpp index 1e8998b365..a2090bceba 100644 --- a/transformer_engine/jax/csrc/extensions/activation.cpp +++ b/transformer_engine/jax/csrc/extensions/activation.cpp @@ -110,7 +110,7 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Result_Type outp auto *output = output_buf->untyped_data(); auto input_dims = input_buf.dimensions(); - auto m = std::accumulate(input_dims.begin(), input_dims.end() - 2, 1, std::multiplies<>()); + auto m = product(input_dims, 0, input_dims.size() - 2); auto n = input_dims.back(); auto act_len = input_dims.end()[-2]; auto act_type = static_cast(act_enum); @@ -126,7 +126,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI, .Ctx() // stream .Arg() // input .Ret() // output - .Attr("act_enum")); + .Attr("act_enum"), + FFI_CudaGraph_Traits); void ActLuFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { auto *input = buffers[0]; @@ -152,6 +153,51 @@ void ActLuFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t op act_enum, act_len); } +Error_Type ActLuFP8FFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf, + Buffer_Type scale_buf, Buffer_Type scale_inv_buf, Result_Type output_buf, + Result_Type amax_out_buf, int64_t act_enum) { + auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); + auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); + + auto *input = input_buf.untyped_data(); + float *amax = reinterpret_cast(amax_buf.untyped_data()); + float *scale = reinterpret_cast(scale_buf.untyped_data()); + float *scale_inv = reinterpret_cast(scale_inv_buf.untyped_data()); + + auto *output = output_buf->untyped_data(); + float *amax_out = 
reinterpret_cast(amax_out_buf->untyped_data()); + NVTE_CHECK(amax == amax_out, "amax not bound to amax_out in TE/JAX ActLuFP8 primitive."); + + if (!use_fp8(out_dtype)) { + scale = nullptr; + scale_inv = nullptr; + amax_out = nullptr; + } + + auto input_dims = input_buf.dimensions(); + auto m = product(input_dims, 0, input_dims.size() - 2); + auto n = input_dims.back(); + auto act_len = input_dims.end()[-2]; + auto act_type = static_cast(act_enum); + + ActLuImpl(input, m, n, in_dtype, out_dtype, scale, stream, scale_inv, amax_out, output, act_type, + act_len); + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuFP8Handler, ActLuFP8FFI, + FFI::Bind() + .Ctx() // stream + .Arg() // input + .Arg() // amax + .Arg() // scale + .Arg() // scale_inv + .Ret() // output + .Ret() // amax_out + .Attr("act_enum"), + FFI_CudaGraph_Traits); + void DActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { auto *input = buffers[0]; auto *act_input = buffers[1]; @@ -218,9 +264,8 @@ Error_Type DActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type act auto *output = output_buf->untyped_data(); auto act_input_dims = act_input_buf.dimensions(); - auto m = - std::accumulate(act_input_dims.begin(), act_input_dims.end() - 2, 1, std::multiplies<>()); - auto n = act_input_dims.back(); + auto m = static_cast(product(act_input_dims, 0, act_input_dims.size() - 2)); + auto n = static_cast(act_input_dims.back()); auto act_len = act_input_dims.end()[-2]; auto input_shape = std::vector{m, n}; @@ -276,7 +321,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuHandler, DActLuFFI, .Arg() // input .Arg() // act_input .Ret() // output - .Attr("act_enum")); + .Attr("act_enum"), + FFI_CudaGraph_Traits); pybind11::tuple GetDActDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType out_dtype) { @@ -327,7 +373,7 @@ void DActLuDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *o auto m = desc.shape.dims[0]; auto n = desc.shape.dims[1]; auto act_enum = static_cast(desc.act_enum); - ; + auto input_shape = std::vector{m, n}; auto act_input_shape = std::vector{m, n}; auto output_shape = std::vector{m, n}; @@ -376,6 +422,107 @@ void DActLuDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *o } } +Error_Type DActLuDBiasCastTransposeFFI(cudaStream_t stream, Buffer_Type input_buf, + Buffer_Type act_input_buf, Buffer_Type amax_buf, + Buffer_Type scale_buf, Buffer_Type scale_inv_buf, + Result_Type output_buf, Result_Type output_trans_buf, + Result_Type dbias_buf, Result_Type amax_out_buf, + Result_Type workspace_buf, int64_t act_enum) { + auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); + auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); + auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type()); + + auto *input = input_buf.untyped_data(); + auto *act_input = act_input_buf.untyped_data(); + float *amax = reinterpret_cast(amax_buf.untyped_data()); + float *scale = reinterpret_cast(scale_buf.untyped_data()); + float *scale_inv = reinterpret_cast(scale_inv_buf.untyped_data()); + auto *output = output_buf->untyped_data(); + auto *output_trans = output_trans_buf->untyped_data(); + auto *dbias = dbias_buf->untyped_data(); + float *amax_out = reinterpret_cast(amax_out_buf->untyped_data()); + void *workspace = workspace_buf->untyped_data(); + NVTE_CHECK(amax == amax_out, + "amax not bound to amax_out in TE/JAX DActLuDBiasCastTranspose 
primitive."); + if (!use_fp8(out_dtype)) { + scale = nullptr; + scale_inv = nullptr; + amax_out = nullptr; + } + + auto input_dims = input_buf.dimensions(); + auto act_input_dims = act_input_buf.dimensions(); + auto workspace_dims = workspace_buf->dimensions(); + // m = x_batch_size = reduce(operator.mul, x_shape[:-2]), x_shape == act_input_dims + // n = ir_dz_shape[-1], ir_dz_shape == input_dims + auto input_ranks = input_dims.size(); + auto m = product(act_input_dims, 0, act_input_dims.size() - 2); + auto n = product(input_dims, input_ranks - 1, input_ranks); + auto input_shape = std::vector{m, n}; + auto act_input_shape = std::vector{m, n}; + auto output_shape = std::vector{m, n}; + auto output_trans_shape = std::vector{n, m}; + auto dbias_shape = std::vector{n}; + std::vector workspace_shape(workspace_dims.begin(), workspace_dims.end()); + + auto input_tensor = TensorWrapper(input, input_shape, in_dtype); + auto act_input_tensor = TensorWrapper(act_input, input_shape, in_dtype); + auto output_tensor = TensorWrapper(output, output_shape, out_dtype, amax_out, scale, scale_inv); + auto output_trans_tensor = + TensorWrapper(output_trans, output_trans_shape, out_dtype, amax_out, scale, scale_inv); + auto dbias_tensor = TensorWrapper(dbias, dbias_shape, in_dtype); + auto workspace_tensor = TensorWrapper(workspace, workspace_shape, workspace_dtype); + + auto act_type = static_cast(act_enum); + switch (act_type) { + case NVTE_Activation_Type::GELU: + nvte_cast_transpose_dbias_dgelu(input_tensor.data(), act_input_tensor.data(), + output_tensor.data(), output_trans_tensor.data(), + dbias_tensor.data(), workspace_tensor.data(), stream); + break; + case NVTE_Activation_Type::SILU: + nvte_cast_transpose_dbias_dsilu(input_tensor.data(), act_input_tensor.data(), + output_tensor.data(), output_trans_tensor.data(), + dbias_tensor.data(), workspace_tensor.data(), stream); + break; + case NVTE_Activation_Type::RELU: + nvte_cast_transpose_dbias_drelu(input_tensor.data(), act_input_tensor.data(), + output_tensor.data(), output_trans_tensor.data(), + dbias_tensor.data(), workspace_tensor.data(), stream); + break; + case NVTE_Activation_Type::QGELU: + nvte_cast_transpose_dbias_dqgelu(input_tensor.data(), act_input_tensor.data(), + output_tensor.data(), output_trans_tensor.data(), + dbias_tensor.data(), workspace_tensor.data(), stream); + break; + case NVTE_Activation_Type::SRELU: + nvte_cast_transpose_dbias_dsrelu(input_tensor.data(), act_input_tensor.data(), + output_tensor.data(), output_trans_tensor.data(), + dbias_tensor.data(), workspace_tensor.data(), stream); + break; + default: + NVTE_ERROR("Unsupported ActivationEnum"); + break; + } + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasCastTransposeHandler, DActLuDBiasCastTransposeFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // input + .Arg() // act_input + .Arg() // amax + .Arg() // scale + .Arg() // scale_inv + .Ret() // output + .Ret() // output_trans + .Ret() // dbias + .Ret() // amax_out + .Ret() // workspace + .Attr("act_enum"), + FFI_CudaGraph_Traits); + void DGatedActLuCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { auto *input = buffers[0]; @@ -398,7 +545,7 @@ void DGatedActLuCastTranspose(cudaStream_t stream, void **buffers, const char *o auto m = desc.shape.dims[0]; auto n = desc.shape.dims[1]; auto act_enum = static_cast(desc.act_enum); - ; + auto input_shape = desc.shape.to_vector(); auto act_input_shape = std::vector{m, n * 2}; auto output_shape = 
std::vector{m, n * 2}; @@ -438,5 +585,88 @@ void DGatedActLuCastTranspose(cudaStream_t stream, void **buffers, const char *o } } +Error_Type DGatedActLuCastTransposeFFI(cudaStream_t stream, Buffer_Type input_buf, + Buffer_Type act_input_buf, Buffer_Type amax_buf, + Buffer_Type scale_buf, Buffer_Type scale_inv_buf, + Result_Type output_buf, Result_Type output_trans_buf, + Result_Type amax_out_buf, int64_t act_enum) { + auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); + auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); + + auto *input = input_buf.untyped_data(); + auto *act_input = act_input_buf.untyped_data(); + float *amax = reinterpret_cast(amax_buf.untyped_data()); + float *scale = reinterpret_cast(scale_buf.untyped_data()); + float *scale_inv = reinterpret_cast(scale_inv_buf.untyped_data()); + auto *output = output_buf->untyped_data(); + auto *output_trans = output_trans_buf->untyped_data(); + float *amax_out = reinterpret_cast(amax_out_buf->untyped_data()); + NVTE_CHECK(amax == amax_out, + "amax not bound to amax_out in TE/JAX DGatedActLuCastTranspose primitive."); + if (!use_fp8(out_dtype)) { + scale = nullptr; + scale_inv = nullptr; + amax_out = nullptr; + } + + auto input_dims = input_buf.dimensions(); + auto act_input_dims = act_input_buf.dimensions(); + auto act_input_ranks = act_input_dims.size(); + auto m = product(act_input_dims, 0, act_input_ranks - 2); + auto n = product(act_input_dims, act_input_ranks - 1, act_input_ranks); + auto input_shape = std::vector{m, n}; + auto act_input_shape = std::vector{m, n * 2}; + auto output_shape = std::vector{m, n * 2}; + auto output_trans_shape = std::vector{n * 2, m}; + + auto input_tensor = TensorWrapper(input, input_shape, in_dtype); + auto act_input_tensor = TensorWrapper(act_input, act_input_shape, in_dtype); + auto output_tensor = TensorWrapper(output, output_shape, out_dtype, amax_out, scale, scale_inv); + auto output_trans_tensor = + TensorWrapper(output_trans, output_trans_shape, out_dtype, amax_out, scale, scale_inv); + + auto act_type = static_cast(act_enum); + switch (act_type) { + case NVTE_Activation_Type::GEGLU: + nvte_dgeglu_cast_transpose(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), + output_trans_tensor.data(), stream); + break; + case NVTE_Activation_Type::SWIGLU: + nvte_dswiglu_cast_transpose(input_tensor.data(), act_input_tensor.data(), + output_tensor.data(), output_trans_tensor.data(), stream); + break; + case NVTE_Activation_Type::REGLU: + nvte_dreglu_cast_transpose(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), + output_trans_tensor.data(), stream); + break; + case NVTE_Activation_Type::QGEGLU: + nvte_dqgeglu_cast_transpose(input_tensor.data(), act_input_tensor.data(), + output_tensor.data(), output_trans_tensor.data(), stream); + break; + case NVTE_Activation_Type::SREGLU: + nvte_dsreglu_cast_transpose(input_tensor.data(), act_input_tensor.data(), + output_tensor.data(), output_trans_tensor.data(), stream); + break; + default: + NVTE_ERROR("Unsupported ActivationEnum"); + break; + } + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(DGatedActLuCastTransposeHandler, DGatedActLuCastTransposeFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // input + .Arg() // act_input + .Arg() // amax + .Arg() // scale + .Arg() // scale_inv + .Ret() // output + .Ret() // output_trans + .Ret() // amax_out + .Attr("act_enum"), + FFI_CudaGraph_Traits); + } // namespace jax } // namespace transformer_engine diff 
--git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index 90aa3f6e2b..4bde10fc46 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -30,18 +30,13 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype, - common/fused_attn/fused_attn_f16_max512_seqlen.cu lines 594-634 and 773-812 - common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu lines 1270-1281 and 1348-1359 */ -void PrepareFusedAttnForwardAuxTensors(NVTETensorPack *tensor_pack, - const CustomCallFusedAttnDescriptor *desc, +void PrepareFusedAttnForwardAuxTensors(NVTETensorPack *tensor_pack, const size_t input_batch, + const size_t bias_batch, const size_t attn_heads, + const size_t bias_heads, const size_t q_max_seqlen, + const size_t kv_max_seqlen, DType dtype, NVTE_Bias_Type bias_type, NVTE_Fused_Attn_Backend backend, void *softmax_buf, void *rng_state_buf = nullptr, void *bias_buf = nullptr) { - auto input_batch = desc->input_batch; - auto bias_batch = desc->bias_batch; - auto attn_heads = desc->attn_heads; - auto bias_heads = desc->bias_heads; - auto q_max_seqlen = desc->q_max_seqlen; - auto kv_max_seqlen = desc->kv_max_seqlen; - // all backends need softmax but expect different shapes/dtypes // start with the max512 sequence length softmax shape/dtype and correct later tensor_pack->size = 1; @@ -49,7 +44,7 @@ void PrepareFusedAttnForwardAuxTensors(NVTETensorPack *tensor_pack, softmax_aux->data.dptr = softmax_buf; softmax_aux->data.shape = std::vector{input_batch, attn_heads, q_max_seqlen, kv_max_seqlen}; - softmax_aux->data.dtype = desc->dtype; + softmax_aux->data.dtype = dtype; // arbitrary sequence length backend needs the RNG state and a different shape/dtype softmax if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { @@ -69,7 +64,7 @@ void PrepareFusedAttnForwardAuxTensors(NVTETensorPack *tensor_pack, bias_aux->data.dptr = bias_buf; bias_aux->data.shape = std::vector{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen}; - bias_aux->data.dtype = desc->dtype; + bias_aux->data.dtype = dtype; } } } @@ -82,22 +77,25 @@ void PrepareFusedAttnForwardAuxTensors(NVTETensorPack *tensor_pack, TODO(Alp): Refactor the nvte_fused_attn_fwd() to work like nvte_fused_attn_bwd()? 
*/ -void PrepareFusedAttnBackwardAuxTensors(NVTETensorPack *tensor_pack, - const CustomCallFusedAttnDescriptor *desc, +void PrepareFusedAttnBackwardAuxTensors(NVTETensorPack *tensor_pack, const size_t input_batch, + const size_t bias_batch, const size_t attn_heads, + const size_t bias_heads, const size_t q_max_seqlen, + const size_t kv_max_seqlen, DType dtype, NVTE_Fused_Attn_Backend backend, void *softmax_buf, void *rng_state_buf, void *bias_buf) { // Backward calls put everything into the tensor pack for every backend // so we set dummy bias_type and backend choices here to follow the correct code path auto dummy_bias_type = NVTE_Bias_Type::NVTE_POST_SCALE_BIAS; auto dummy_backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen; - PrepareFusedAttnForwardAuxTensors(tensor_pack, desc, dummy_bias_type, dummy_backend, softmax_buf, - rng_state_buf, bias_buf); + PrepareFusedAttnForwardAuxTensors(tensor_pack, input_batch, bias_batch, attn_heads, bias_heads, + q_max_seqlen, kv_max_seqlen, dtype, dummy_bias_type, + dummy_backend, softmax_buf, rng_state_buf, bias_buf); // correct softmax shape for max512 sequence length kernel if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { Tensor *softmax_aux = reinterpret_cast(tensor_pack->tensors[0]); - softmax_aux->data.shape.at(3) = desc->kv_max_seqlen; // {B,H,Qs,1} -> {B,H,Qs,Ks} - softmax_aux->data.dtype = desc->dtype; + softmax_aux->data.shape.at(3) = kv_max_seqlen; // {B,H,Qs,1} -> {B,H,Qs,Ks} + softmax_aux->data.dtype = dtype; } } @@ -187,82 +185,52 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( return pybind11::make_tuple(workspace_shape, query_workspace_tensor.dtype()); } -void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { - const CustomCallFusedAttnDescriptor &descriptor = - *UnpackOpaque(opaque, opaque_len); - - auto qkv_layout = descriptor.qkv_layout; - auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD; - - /* Input buffers from XLA */ - /* Buffers[0-2] are q, k, v, which are parsed later for different qkv_layout */ - void *bias = buffers[3]; - void *q_cu_seqlens = buffers[4]; - void *kv_cu_seqlens = buffers[5]; - void *q_seq_offsets = is_ragged ? buffers[6] : nullptr; - void *k_seq_offsets = is_ragged ? 
buffers[7] : nullptr; - void *seed = buffers[8]; - - /* Output buffer from XLA */ - void *output = buffers[9]; - void *softmax_aux = buffers[10]; - void *rng_state = buffers[11]; - void *workspace = buffers[12]; +#define FUSED_ATTN_IMPL_COMMON_BLOCK \ + auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD; \ + auto bias_shape = std::vector{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen}; \ + size_t num_segments = input_batch; \ + if (is_ragged) { \ + auto cudnn_runtime_version = cudnnGetVersion(); \ + if (cudnn_runtime_version >= 90300) { \ + num_segments = input_batch * max_segments_per_seq; \ + } else { \ + size_t runtime_num_segments_q = \ + GetRuntimeNumSegments(q_cu_seqlens, workspace, input_batch * q_max_seqlen, stream); \ + size_t runtime_num_segments_kv = \ + GetRuntimeNumSegments(kv_cu_seqlens, workspace, input_batch * kv_max_seqlen, stream); \ + NVTE_CHECK(runtime_num_segments_q == runtime_num_segments_kv); \ + NVTE_CHECK(runtime_num_segments_q <= input_batch * max_segments_per_seq); \ + num_segments = runtime_num_segments_q; \ + } \ + } \ + std::vector seq_shape{num_segments + 1}; \ + auto q_cu_seqlens_tensor = TensorWrapper(q_cu_seqlens, seq_shape, DType::kInt32); \ + auto kv_cu_seqlens_tensor = TensorWrapper(kv_cu_seqlens, seq_shape, DType::kInt32); \ + auto q_seq_offsets_tensor = TensorWrapper(q_seq_offsets, seq_shape, DType::kInt32); \ + auto k_seq_offsets_tensor = TensorWrapper(k_seq_offsets, seq_shape, DType::kInt32); \ + auto workspace_tensor = \ + TensorWrapper(workspace, std::vector{wkspace_size}, wkspace_dtype); \ + auto layout_group = nvte_get_qkv_layout_group(qkv_layout); - /* Descriptor */ - auto input_batch = descriptor.input_batch; - auto bias_batch = descriptor.bias_batch; - auto q_max_seqlen = descriptor.q_max_seqlen; - auto kv_max_seqlen = descriptor.kv_max_seqlen; - auto attn_heads = descriptor.attn_heads; - auto num_gqa_groups = descriptor.num_gqa_groups; - auto bias_heads = descriptor.bias_heads; - auto head_dim = descriptor.head_dim; - auto scaling_factor = descriptor.scaling_factor; - auto dropout_probability = descriptor.dropout_probability; - auto bias_type = descriptor.bias_type; - auto mask_type = descriptor.mask_type; - auto dtype = descriptor.dtype; - auto is_training = descriptor.is_training; - auto max_segments_per_seq = descriptor.max_segments_per_seq; - auto window_size_left = descriptor.window_size_left; - auto window_size_right = descriptor.window_size_right; +static void FusedAttnForwardImpl( + cudaStream_t stream, void *q, void *k, void *v, void *bias, void *q_cu_seqlens, + void *kv_cu_seqlens, void *q_seq_offsets, void *k_seq_offsets, void *seed, void *output, + void *softmax_aux, void *rng_state, void *workspace, size_t input_batch, size_t bias_batch, + size_t q_max_seqlen, size_t kv_max_seqlen, size_t attn_heads, size_t num_gqa_groups, + size_t bias_heads, size_t head_dim, size_t max_segments_per_seq, size_t wkspace_size, + float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, + bool is_training, bool deterministic, int64_t window_size_left, int64_t window_size_right) { + FUSED_ATTN_IMPL_COMMON_BLOCK; /* Input tensors */ - auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; - auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, head_dim}; - auto v_shape = k_shape; - auto bias_shape = std::vector{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen}; auto 
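The FUSED_ATTN_IMPL_COMMON_BLOCK macro above sizes the cumulative-seqlen and offset tensors by a segment count whose rule depends on the cuDNN version. A standalone sketch of that rule, assuming a free choose_num_segments helper:

#include <cstddef>

// Non-THD layouts use one segment per batch entry. THD (ragged) layouts use the static bound
// input_batch * max_segments_per_seq on cuDNN >= 9.3, otherwise the runtime segment count
// discovered from the cu_seqlens buffers (which is checked against the same bound).
size_t choose_num_segments(bool is_ragged, size_t input_batch, size_t max_segments_per_seq,
                           size_t runtime_num_segments, size_t cudnn_runtime_version) {
  if (!is_ragged) {
    return input_batch;
  }
  if (cudnn_runtime_version >= 90300) {
    return input_batch * max_segments_per_seq;
  }
  return runtime_num_segments;
}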
bias_tensor = TensorWrapper(bias, bias_shape, dtype); - size_t num_segments = input_batch; // Non-THD format, input_batch = num_segments if (is_ragged) { - auto cudnn_runtime_version = cudnnGetVersion(); - if (cudnn_runtime_version >= 90300) { - num_segments = input_batch * max_segments_per_seq; - } else { - // workspace can be reused here as it is not used with cuDNN graph at the same time - size_t runtime_num_segments_q = - GetRuntimeNumSegments(q_cu_seqlens, workspace, input_batch * q_max_seqlen, stream); - size_t runtime_num_segments_kv = - GetRuntimeNumSegments(kv_cu_seqlens, workspace, input_batch * kv_max_seqlen, stream); - NVTE_CHECK(runtime_num_segments_q == runtime_num_segments_kv); - NVTE_CHECK(runtime_num_segments_q <= input_batch * max_segments_per_seq); - num_segments = runtime_num_segments_q; - } - cudaMemsetAsync(output, 0, - input_batch * q_max_seqlen * attn_heads * head_dim * typeToSize(dtype), stream); + auto output_size = input_batch * q_max_seqlen * attn_heads * head_dim; + cudaMemsetAsync(output, 0, output_size * typeToSize(dtype), stream); } - auto q_cu_seqlens_tensor = - TensorWrapper(q_cu_seqlens, std::vector{num_segments + 1}, DType::kInt32); - auto kv_cu_seqlens_tensor = - TensorWrapper(kv_cu_seqlens, std::vector{num_segments + 1}, DType::kInt32); - auto q_seq_offsets_tensor = - TensorWrapper(q_seq_offsets, std::vector{num_segments + 1}, DType::kInt32); - auto k_seq_offsets_tensor = - TensorWrapper(k_seq_offsets, std::vector{num_segments + 1}, DType::kInt32); - /* Output tensors */ auto s_tensor = TensorWrapper(nullptr, std::vector{1}, dtype); // not used in F16 auto o_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; @@ -279,32 +247,25 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s /* Auxiliary tensors (to be propagated to the backward pass later) */ NVTETensorPack aux_output_tensors; nvte_tensor_pack_create(&aux_output_tensors); - PrepareFusedAttnForwardAuxTensors(&aux_output_tensors, &descriptor, bias_type, backend, - softmax_aux); - - /* cuDNN workspace */ - auto workspace_tensor = TensorWrapper(workspace, std::vector{descriptor.wkspace_size}, - descriptor.wkspace_dtype); + PrepareFusedAttnForwardAuxTensors(&aux_output_tensors, input_batch, bias_batch, attn_heads, + bias_heads, q_max_seqlen, kv_max_seqlen, dtype, bias_type, + backend, softmax_aux); - /* Call the underly NVTE API */ - auto layout_group = nvte_get_qkv_layout_group(qkv_layout); + /* Call the underlying NVTE API */ if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - auto qkv = buffers[0]; auto qkv_shape = std::vector{input_batch * q_max_seqlen, 3, attn_heads, head_dim}; - auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype); - nvte_fused_attn_fwd_qkvpacked( - qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), - &aux_output_tensors, q_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), - rng_state_tensor.data(), q_max_seqlen, is_training, descriptor.scaling_factor, - dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, - workspace_tensor.data(), stream); + auto qkv_tensor = TensorWrapper(q, qkv_shape, dtype); + nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), + o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), + q_seq_offsets_tensor.data(), rng_state_tensor.data(), + q_max_seqlen, is_training, scaling_factor, dropout_probability, + qkv_layout, bias_type, mask_type, window_size_left, + window_size_right, 
workspace_tensor.data(), stream); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - auto q = buffers[0]; auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; - auto q_tensor = TensorWrapper(q, q_shape, dtype); - auto kv = buffers[1]; auto kv_shape = std::vector{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim}; - auto kv_tensor = TensorWrapper(kv, kv_shape, dtype); + auto q_tensor = TensorWrapper(q, q_shape, dtype); + auto kv_tensor = TensorWrapper(k, kv_shape, dtype); nvte_fused_attn_fwd_kvpacked( q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), @@ -312,14 +273,11 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, workspace_tensor.data(), stream); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { - auto q = buffers[0]; auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; - auto q_tensor = TensorWrapper(q, q_shape, dtype); - auto k = buffers[1]; auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, head_dim}; - auto k_tensor = TensorWrapper(k, k_shape, dtype); - auto v = buffers[2]; auto v_shape = k_shape; + auto q_tensor = TensorWrapper(q, q_shape, dtype); + auto k_tensor = TensorWrapper(k, k_shape, dtype); auto v_tensor = TensorWrapper(v, v_shape, dtype); nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors, @@ -335,6 +293,108 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s nvte_tensor_pack_destroy(&aux_output_tensors); } +void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { + const CustomCallFusedAttnDescriptor &descriptor = + *UnpackOpaque(opaque, opaque_len); + auto is_ragged = nvte_get_qkv_format(descriptor.qkv_layout) == NVTE_QKV_Format::NVTE_THD; + + /* Input buffers from XLA */ + void *q = buffers[0]; + void *k = buffers[1]; + void *v = buffers[2]; + void *bias = buffers[3]; + void *q_cu_seqlens = buffers[4]; + void *kv_cu_seqlens = buffers[5]; + void *q_seq_offsets = is_ragged ? buffers[6] : nullptr; + void *k_seq_offsets = is_ragged ? 
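The implementation above branches on the QKV layout group to decide how many distinct device buffers actually carry Q, K and V. A sketch of that mapping with a hypothetical helper (not a TE API):

// 3HD packs Q, K and V into one buffer (q points at it), HD_2HD packs K and V together
// (k points at the packed KV buffer), HD_HD_HD keeps three independent buffers.
enum class QkvLayoutGroup { k3HD, kHD_2HD, kHD_HD_HD };

int num_qkv_buffers(QkvLayoutGroup group) {
  switch (group) {
    case QkvLayoutGroup::k3HD:      return 1;  // q is the packed QKV buffer
    case QkvLayoutGroup::kHD_2HD:   return 2;  // q, plus k as the packed KV buffer
    case QkvLayoutGroup::kHD_HD_HD: return 3;  // q, k and v are separate buffers
  }
  return 0;
}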
buffers[7] : nullptr; + void *seed = buffers[8]; + + /* Output buffer from XLA */ + void *output = buffers[9]; + void *softmax_aux = buffers[10]; + void *rng_state = buffers[11]; + void *workspace = buffers[12]; + + FusedAttnForwardImpl( + stream, q, k, v, bias, q_cu_seqlens, kv_cu_seqlens, q_seq_offsets, k_seq_offsets, seed, + output, softmax_aux, rng_state, workspace, descriptor.input_batch, descriptor.bias_batch, + descriptor.q_max_seqlen, descriptor.kv_max_seqlen, descriptor.attn_heads, + descriptor.num_gqa_groups, descriptor.bias_heads, descriptor.head_dim, + descriptor.max_segments_per_seq, descriptor.wkspace_size, descriptor.scaling_factor, + descriptor.dropout_probability, descriptor.bias_type, descriptor.mask_type, + descriptor.qkv_layout, descriptor.dtype, descriptor.wkspace_dtype, descriptor.is_training, + descriptor.deterministic, descriptor.window_size_left, descriptor.window_size_right); +} + +#define FUSED_ATTN_FFI_GET_ATTRS \ + size_t input_batch = get_attr_value(attrs, "input_batch"); \ + size_t bias_batch = get_attr_value(attrs, "bias_batch"); \ + size_t q_max_seqlen = get_attr_value(attrs, "q_max_seqlen"); \ + size_t kv_max_seqlen = get_attr_value(attrs, "kv_max_seqlen"); \ + size_t attn_heads = get_attr_value(attrs, "attn_heads"); \ + size_t num_gqa_groups = get_attr_value(attrs, "num_gqa_groups"); \ + size_t bias_heads = get_attr_value(attrs, "bias_heads"); \ + size_t head_dim = get_attr_value(attrs, "head_dim"); \ + size_t max_segments_per_seq = get_attr_value(attrs, "max_segments_per_seq"); \ + auto window_size_left = get_attr_value(attrs, "window_size_left"); \ + auto window_size_right = get_attr_value(attrs, "window_size_right"); \ + float scaling_factor = get_attr_value(attrs, "scaling_factor"); \ + float dropout_probability = get_attr_value(attrs, "dropout_probability"); \ + NVTE_Bias_Type bias_type = \ + static_cast(get_attr_value(attrs, "bias_type")); \ + NVTE_Mask_Type mask_type = \ + static_cast(get_attr_value(attrs, "mask_type")); \ + NVTE_QKV_Layout qkv_layout = \ + static_cast(get_attr_value(attrs, "qkv_layout")); \ + bool is_training = get_attr_value(attrs, "is_training"); \ + bool deterministic = get_attr_value(attrs, "deterministic"); \ + auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD; \ + size_t wkspace_size = product(workspace_buf->dimensions()); \ + DType dtype = convert_ffi_datatype_to_te_dtype(q_buf.element_type()); \ + DType wkspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type()); + +Error_Type FusedAttnForwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_Type k_buf, + Buffer_Type v_buf, Buffer_Type bias_buf, + Buffer_Type q_cu_seqlens_buf, Buffer_Type kv_cu_seqlens_buf, + Buffer_Type q_seq_offsets_buf, Buffer_Type k_seq_offsets_buf, + Buffer_Type seed_buf, Result_Type output_buf, + Result_Type softmax_aux_buf, Result_Type rng_state_buf, + Result_Type workspace_buf, Dictionary attrs) { + FUSED_ATTN_FFI_GET_ATTRS; + + FusedAttnForwardImpl( + stream, q_buf.untyped_data(), k_buf.untyped_data(), v_buf.untyped_data(), + bias_buf.untyped_data(), q_cu_seqlens_buf.untyped_data(), kv_cu_seqlens_buf.untyped_data(), + is_ragged ? q_seq_offsets_buf.untyped_data() : nullptr, + is_ragged ? 
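FUSED_ATTN_FFI_GET_ATTRS pulls every kernel parameter out of the FFI attribute dictionary by name and static type. A rough analogue of that contract using a plain std::variant map (hypothetical require_attr helper, not the XLA Dictionary API):

#include <cstdint>
#include <map>
#include <stdexcept>
#include <string>
#include <variant>

using AttrValue = std::variant<std::int64_t, double, bool>;
using AttrMap = std::map<std::string, AttrValue>;

// Fail loudly, like get_attr_value(), when a name is missing or its type does not match.
template <typename T>
T require_attr(const AttrMap &attrs, const std::string &name) {
  auto it = attrs.find(name);
  if (it == attrs.end() || !std::holds_alternative<T>(it->second)) {
    throw std::runtime_error("attribute '" + name + "' missing or of unexpected type");
  }
  return std::get<T>(it->second);
}

// Usage mirroring the macro: every lookup is keyed by the same name and static type that the
// Python lowering attached to the custom call, e.g.
//   auto attn_heads = require_attr<std::int64_t>(attrs, "attn_heads");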
k_seq_offsets_buf.untyped_data() : nullptr, seed_buf.untyped_data(), + output_buf->untyped_data(), softmax_aux_buf->untyped_data(), rng_state_buf->untyped_data(), + workspace_buf->untyped_data(), input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, + attn_heads, num_gqa_groups, bias_heads, head_dim, max_segments_per_seq, wkspace_size, + scaling_factor, dropout_probability, bias_type, mask_type, qkv_layout, dtype, wkspace_dtype, + is_training, deterministic, window_size_left, window_size_right); + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedAttnForwardHandler, FusedAttnForwardFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // q + .Arg() // k + .Arg() // v + .Arg() // bias + .Arg() // q_cu_seqlens + .Arg() // kv_cu_seqlens + .Arg() // q_seq_offsets + .Arg() // k_seq_offsets + .Arg() // seed_buf + .Ret() // output + .Ret() // softmax_aux + .Ret() // rng_state + .Ret() // workspace + .Attrs(), + FFI_CudaGraph_Traits); + pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim, @@ -437,81 +497,23 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype()); } -void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { - const CustomCallFusedAttnDescriptor &descriptor = - *UnpackOpaque(opaque, opaque_len); - - auto qkv_layout = descriptor.qkv_layout; - auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD; - - /* Input buffers from XLA */ - /* Buffers[0-2] are q, k, v, which are parsed later for different qkv_layout */ - void *bias = buffers[3]; - void *softmax_aux = buffers[4]; - void *rng_state = buffers[5]; - void *output = buffers[6]; - void *doutput = buffers[7]; - void *q_cu_seqlens = buffers[8]; - void *kv_cu_seqlens = buffers[9]; - void *q_seq_offsets = is_ragged ? buffers[10] : nullptr; - void *k_seq_offsets = is_ragged ? 
buffers[11] : nullptr; - - /* Output buffer from XLA */ - /* Buffers[12-14] are dq, dk, dv, which are parsed later for different qkv_layout */ - void *dbias = buffers[15]; - void *workspace = buffers[16]; - - /* Descriptor */ - auto input_batch = descriptor.input_batch; - auto bias_batch = descriptor.bias_batch; - auto q_max_seqlen = descriptor.q_max_seqlen; - auto kv_max_seqlen = descriptor.kv_max_seqlen; - auto attn_heads = descriptor.attn_heads; - auto num_gqa_groups = descriptor.num_gqa_groups; - auto bias_heads = descriptor.bias_heads; - auto head_dim = descriptor.head_dim; - auto scaling_factor = descriptor.scaling_factor; - auto dropout_probability = descriptor.dropout_probability; - auto bias_type = descriptor.bias_type; - auto mask_type = descriptor.mask_type; - auto dtype = descriptor.dtype; - auto deterministic = descriptor.deterministic; - auto max_segments_per_seq = descriptor.max_segments_per_seq; - auto window_size_left = descriptor.window_size_left; - auto window_size_right = descriptor.window_size_right; +static void FusedAttnBackwardImpl( + cudaStream_t stream, void *q, void *k, void *v, void *bias, void *softmax_aux, void *rng_state, + void *output, void *doutput, void *q_cu_seqlens, void *kv_cu_seqlens, void *q_seq_offsets, + void *k_seq_offsets, void *dq, void *dk, void *dv, void *dbias, void *workspace, + size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, + size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim, + size_t max_segments_per_seq, size_t wkspace_size, float scaling_factor, + float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training, + bool deterministic, int64_t window_size_left, int64_t window_size_right) { + FUSED_ATTN_IMPL_COMMON_BLOCK; /* Input tensors */ auto output_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; - auto bias_shape = std::vector{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen}; auto output_tensor = TensorWrapper(output, output_shape, dtype); auto doutput_tensor = TensorWrapper(doutput, output_shape, dtype); - size_t num_segments = input_batch; // Non-THD format, input_batch = num_segments - if (is_ragged) { - auto cudnn_runtime_version = cudnnGetVersion(); - if (cudnn_runtime_version >= 90300) { - num_segments = input_batch * max_segments_per_seq; - } else { - // workspace can be reused here as it is not used with cuDNN graph at the same time - size_t runtime_num_segments_q = - GetRuntimeNumSegments(q_cu_seqlens, workspace, input_batch * q_max_seqlen, stream); - size_t runtime_num_segments_kv = - GetRuntimeNumSegments(kv_cu_seqlens, workspace, input_batch * kv_max_seqlen, stream); - NVTE_CHECK(runtime_num_segments_q == runtime_num_segments_kv); - NVTE_CHECK(runtime_num_segments_q <= input_batch * max_segments_per_seq); - num_segments = runtime_num_segments_q; - } - } - - auto q_cu_seqlens_tensor = - TensorWrapper(q_cu_seqlens, std::vector{num_segments + 1}, DType::kInt32); - auto kv_cu_seqlens_tensor = - TensorWrapper(kv_cu_seqlens, std::vector{num_segments + 1}, DType::kInt32); - auto q_seq_offsets_tensor = - TensorWrapper(q_seq_offsets, std::vector{num_segments + 1}, DType::kInt32); - auto k_seq_offsets_tensor = - TensorWrapper(k_seq_offsets, std::vector{num_segments + 1}, DType::kInt32); - /* Output tensors */ auto s_tensor = TensorWrapper(nullptr, std::vector{1}, dtype); // not used in F16 auto dbias_tensor = TensorWrapper(dbias, bias_shape, dtype); @@ 
-523,26 +525,17 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, static_cast(dtype), static_cast(dtype), qkv_layout, bias_type, mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen, head_dim, head_dim, window_size_left, window_size_right); - PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, &descriptor, backend, softmax_aux, - rng_state, bias); - - /* cuDNN workspace */ - auto wkspace_size = std::vector{descriptor.wkspace_size}; - auto wkspace_dtype = descriptor.wkspace_dtype; - auto workspace_tensor = TensorWrapper(workspace, wkspace_size, wkspace_dtype); + PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, input_batch, bias_batch, attn_heads, + bias_heads, q_max_seqlen, kv_max_seqlen, dtype, backend, + softmax_aux, rng_state, bias); /* Call the underly NVTE API */ - auto layout_group = nvte_get_qkv_layout_group(qkv_layout); if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - auto qkv = buffers[0]; auto qkv_shape = std::vector{input_batch * q_max_seqlen, 3, attn_heads, head_dim}; - auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype); - auto dqkv = buffers[12]; - auto dqkv_tensor = TensorWrapper(dqkv, qkv_shape, dtype); + auto qkv_tensor = TensorWrapper(q, qkv_shape, dtype); + auto dqkv_tensor = TensorWrapper(dq, qkv_shape, dtype); if (is_ragged) { - size_t dqkv_size = - std::accumulate(qkv_shape.cbegin(), qkv_shape.cend(), 1, std::multiplies()); - cudaMemsetAsync(dqkv, 0, dqkv_size * typeToSize(dtype), stream); + cudaMemsetAsync(dq, 0, transformer_engine::product(qkv_shape) * typeToSize(dtype), stream); } nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(), s_tensor.data(), // not used for F16 @@ -553,23 +546,15 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, bias_type, mask_type, window_size_left, window_size_right, deterministic, workspace_tensor.data(), stream); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - auto q = buffers[0]; auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; - auto q_tensor = TensorWrapper(q, q_shape, dtype); - auto kv = buffers[1]; auto kv_shape = std::vector{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim}; - auto kv_tensor = TensorWrapper(kv, kv_shape, dtype); - auto dq = buffers[12]; + auto q_tensor = TensorWrapper(q, q_shape, dtype); + auto kv_tensor = TensorWrapper(k, kv_shape, dtype); auto dq_tensor = TensorWrapper(dq, q_shape, dtype); - auto dkv = buffers[13]; - auto dkv_tensor = TensorWrapper(dkv, kv_shape, dtype); + auto dkv_tensor = TensorWrapper(dk, kv_shape, dtype); if (is_ragged) { - size_t dq_size = - std::accumulate(q_shape.cbegin(), q_shape.cend(), 1, std::multiplies()); - size_t dkv_size = - std::accumulate(kv_shape.cbegin(), kv_shape.cend(), 1, std::multiplies()); - cudaMemsetAsync(dq, 0, dq_size * typeToSize(dtype), stream); - cudaMemsetAsync(dkv, 0, dkv_size * typeToSize(dtype), stream); + cudaMemsetAsync(dq, 0, transformer_engine::product(q_shape) * typeToSize(dtype), stream); + cudaMemsetAsync(dk, 0, transformer_engine::product(kv_shape) * typeToSize(dtype), stream); } nvte_fused_attn_bwd_kvpacked( q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(), @@ -581,30 +566,19 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, deterministic, workspace_tensor.data(), stream); } else if 
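In the ragged (THD) path the backward implementation zero-fills the gradient buffers before the fused kernel runs, because packed-sequence kernels only write the populated token positions. A sketch of the byte-count computation behind those cudaMemsetAsync calls, assuming a plain shape vector:

#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

// Full element count of the gradient shape times the element size, so padding regions the
// fused kernel never touches still come back as zeros.
size_t bytes_to_zero(const std::vector<size_t> &shape, size_t bytes_per_element) {
  const size_t elements =
      std::accumulate(shape.begin(), shape.end(), size_t{1}, std::multiplies<size_t>());
  return elements * bytes_per_element;
}

// e.g. dq with shape {batch * q_max_seqlen, attn_heads, head_dim} in bf16 zeroes
// bytes_to_zero(shape, 2) bytes before nvte_fused_attn_bwd() runs.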
(layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { - auto q = buffers[0]; auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; - auto q_tensor = TensorWrapper(q, q_shape, dtype); - auto k = buffers[1]; auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, head_dim}; - auto k_tensor = TensorWrapper(k, k_shape, dtype); - auto v = buffers[2]; auto v_shape = k_shape; + auto q_tensor = TensorWrapper(q, q_shape, dtype); + auto k_tensor = TensorWrapper(k, k_shape, dtype); auto v_tensor = TensorWrapper(v, v_shape, dtype); - auto dq = buffers[12]; auto dq_tensor = TensorWrapper(dq, q_shape, dtype); - auto dk = buffers[13]; auto dk_tensor = TensorWrapper(dk, k_shape, dtype); - auto dv = buffers[14]; auto dv_tensor = TensorWrapper(dv, v_shape, dtype); if (is_ragged) { - size_t dq_size = - std::accumulate(q_shape.cbegin(), q_shape.cend(), 1, std::multiplies()); - size_t dk_size = - std::accumulate(k_shape.cbegin(), k_shape.cend(), 1, std::multiplies()); - size_t dv_size = dk_size; - cudaMemsetAsync(dq, 0, dq_size * typeToSize(dtype), stream); - cudaMemsetAsync(dk, 0, dk_size * typeToSize(dtype), stream); - cudaMemsetAsync(dv, 0, dv_size * typeToSize(dtype), stream); + cudaMemsetAsync(dq, 0, transformer_engine::product(q_shape) * typeToSize(dtype), stream); + cudaMemsetAsync(dk, 0, transformer_engine::product(k_shape) * typeToSize(dtype), stream); + cudaMemsetAsync(dv, 0, transformer_engine::product(v_shape) * typeToSize(dtype), stream); } nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(), doutput_tensor.data(), @@ -623,5 +597,93 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, nvte_tensor_pack_destroy(&aux_input_tensors); } +void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { + const CustomCallFusedAttnDescriptor &descriptor = + *UnpackOpaque(opaque, opaque_len); + + auto qkv_layout = descriptor.qkv_layout; + auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD; + + /* Input buffers from XLA */ + void *q = buffers[0]; + void *k = buffers[1]; + void *v = buffers[2]; + void *bias = buffers[3]; + void *softmax_aux = buffers[4]; + void *rng_state = buffers[5]; + void *output = buffers[6]; + void *doutput = buffers[7]; + void *q_cu_seqlens = buffers[8]; + void *kv_cu_seqlens = buffers[9]; + void *q_seq_offsets = is_ragged ? buffers[10] : nullptr; + void *k_seq_offsets = is_ragged ? 
buffers[11] : nullptr; + + /* Output buffer from XLA */ + void *dq = buffers[12]; + void *dk = buffers[13]; + void *dv = buffers[14]; + void *dbias = buffers[15]; + void *workspace = buffers[16]; + + FusedAttnBackwardImpl( + stream, q, k, v, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlens, kv_cu_seqlens, + q_seq_offsets, k_seq_offsets, dq, dk, dv, dbias, workspace, descriptor.input_batch, + descriptor.bias_batch, descriptor.q_max_seqlen, descriptor.kv_max_seqlen, + descriptor.attn_heads, descriptor.num_gqa_groups, descriptor.bias_heads, descriptor.head_dim, + descriptor.max_segments_per_seq, descriptor.wkspace_size, descriptor.scaling_factor, + descriptor.dropout_probability, descriptor.bias_type, descriptor.mask_type, + descriptor.qkv_layout, descriptor.dtype, descriptor.wkspace_dtype, descriptor.is_training, + descriptor.deterministic, descriptor.window_size_left, descriptor.window_size_right); +} + +Error_Type FusedAttnBackwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_Type k_buf, + Buffer_Type v_buf, Buffer_Type bias_buf, + Buffer_Type softmax_aux_buf, Buffer_Type rng_state_buf, + Buffer_Type output_buf, Buffer_Type doutput_buf, + Buffer_Type q_cu_seqlens_buf, Buffer_Type kv_cu_seqlens_buf, + Buffer_Type q_seq_offsets_buf, Buffer_Type k_seq_offsets_buf, + Result_Type dq_buf, Result_Type dk_buf, Result_Type dv_buf, + Result_Type dbias_buf, Result_Type workspace_buf, + Dictionary attrs) { + FUSED_ATTN_FFI_GET_ATTRS; + + FusedAttnBackwardImpl( + stream, q_buf.untyped_data(), k_buf.untyped_data(), v_buf.untyped_data(), + bias_buf.untyped_data(), softmax_aux_buf.untyped_data(), rng_state_buf.untyped_data(), + output_buf.untyped_data(), doutput_buf.untyped_data(), q_cu_seqlens_buf.untyped_data(), + kv_cu_seqlens_buf.untyped_data(), is_ragged ? q_seq_offsets_buf.untyped_data() : nullptr, + is_ragged ? k_seq_offsets_buf.untyped_data() : nullptr, dq_buf->untyped_data(), + dk_buf->untyped_data(), dv_buf->untyped_data(), dbias_buf->untyped_data(), + workspace_buf->untyped_data(), input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, + attn_heads, num_gqa_groups, bias_heads, head_dim, max_segments_per_seq, wkspace_size, + scaling_factor, dropout_probability, bias_type, mask_type, qkv_layout, dtype, wkspace_dtype, + is_training, deterministic, window_size_left, window_size_right); + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedAttnBackwardHandler, FusedAttnBackwardFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // q + .Arg() // k + .Arg() // v + .Arg() // bias + .Arg() // softmax_aux + .Arg() // rng_state + .Arg() // output + .Arg() // doutput + .Arg() // q_cu_seqlens + .Arg() // kv_cu_seqlens + .Arg() // q_seq_offsets + .Arg() // k_seq_offsets + .Ret() // dq + .Ret() // dk + .Ret() // dv + .Ret() // dbias + .Ret() // workspace + .Attrs(), + FFI_CudaGraph_Traits); + } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/cudnn.cpp b/transformer_engine/jax/csrc/extensions/cudnn.cpp new file mode 100644 index 0000000000..95f505e226 --- /dev/null +++ b/transformer_engine/jax/csrc/extensions/cudnn.cpp @@ -0,0 +1,24 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include "transformer_engine/cudnn.h" + +#include "extensions.h" +#include "xla/ffi/api/c_api.h" + +namespace transformer_engine { +namespace jax { + +Error_Type CudnnHandleInitFFI(Variadic_Buffer_Type args, Variadic_Result_Type rets, + Dictionary attrs) { + nvte_cudnn_handle_init(); + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(CudnnHandleInitHandler, CudnnHandleInitFFI, + FFI::Bind().RemainingArgs().RemainingRets().Attrs()); +} // namespace jax +} // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/ffi.cpp b/transformer_engine/jax/csrc/extensions/ffi.cpp index 19fd50cbd1..8b627aad35 100644 --- a/transformer_engine/jax/csrc/extensions/ffi.cpp +++ b/transformer_engine/jax/csrc/extensions/ffi.cpp @@ -7,20 +7,27 @@ #include -#include "common/util/logging.h" - namespace transformer_engine { namespace jax { // For XLA_FFI_DataType Enum Reference: https://github.com/openxla/xla/blob/d054e8366c4e8807726961feeb28b1cdba681888/xla/ffi/api/c_api.h#L163-L186 DType convert_ffi_datatype_to_te_dtype(const xla::ffi::DataType &type) { switch (type) { - case xla::ffi::DataType::F16: - return DType::kFloat16; + case xla::ffi::DataType::U8: + return DType::kByte; + break; + case xla::ffi::DataType::S32: + return DType::kInt32; + break; + case xla::ffi::DataType::S64: + return DType::kInt64; break; case xla::ffi::DataType::F32: return DType::kFloat32; break; + case xla::ffi::DataType::F16: + return DType::kFloat16; + break; case xla::ffi::DataType::BF16: return DType::kBFloat16; break; diff --git a/transformer_engine/jax/csrc/extensions/ffi.h b/transformer_engine/jax/csrc/extensions/ffi.h index 77132c3fca..d886064cae 100644 --- a/transformer_engine/jax/csrc/extensions/ffi.h +++ b/transformer_engine/jax/csrc/extensions/ffi.h @@ -9,17 +9,77 @@ #include +#include "common/util/logging.h" + namespace transformer_engine { namespace jax { using Buffer_Type = xla::ffi::AnyBuffer; using Result_Type = xla::ffi::Result; +using Variadic_Buffer_Type = xla::ffi::RemainingArgs; +using Variadic_Result_Type = xla::ffi::RemainingRets; using Error_Type = xla::ffi::Error; using FFI = xla::ffi::Ffi; using FFI_Stream_Type = xla::ffi::PlatformStream; +using Dictionary = xla::ffi::Dictionary; + +constexpr auto FFI_Prepare = xla::ffi::ExecutionStage::kPrepare; +constexpr auto FFI_CudaGraph_Traits = {xla::ffi::Traits::kCmdBufferCompatible}; + +DType convert_ffi_datatype_to_te_dtype(const xla::ffi::DataType& type); -DType convert_ffi_datatype_to_te_dtype(const xla::ffi::DataType &type); Error_Type ffi_with_cuda_error_check(); +// source_location is not available in C++17, so we implement it ourselves +#if defined(__GNUC__) || defined(__clang__) +#define CURRENT_FILE __builtin_FILE() +#define CURRENT_LINE __builtin_LINE() +#define CURRENT_FUNCTION __builtin_FUNCTION() +#else +#define CURRENT_FILE __FILE__ +#define CURRENT_LINE __LINE__ +#define CURRENT_FUNCTION __func__ +#endif + +class source_location { + public: + static source_location current(const char* file = CURRENT_FILE, int line = CURRENT_LINE, + const char* function = CURRENT_FUNCTION) { + return source_location(file, line, function); + } + + constexpr const char* file_name() const { return file_; } + constexpr int line() const { return line_; } + constexpr const char* function_name() const { return function_; } + + private: + constexpr source_location(const char* file, int line, const char* function) + : file_(file), line_(line), 
function_(function) {} + + const char* file_; + int line_; + const char* function_; +}; + +template +T get_attr_value(Dictionary& attrs, std::string attr_name, + const source_location& loc = source_location::current()) { + auto attr = attrs.get(attr_name); + if (attr.has_error()) { + NVTE_ERROR("Failure in getting attribute value of '", attr_name, "'\n", + "Called from: ", loc.file_name(), ":", loc.line(), "\n", + "In function: ", loc.function_name(), "\n", + "Please ensure the attribute name and datatype match between C++ and Python APIs."); + } + return attr.value(); +} + +inline size_t product(const xla::ffi::Span& data, size_t start_idx = 0, + size_t end_idx = 0) { + end_idx = (end_idx == 0) ? data.size() : end_idx; + return std::accumulate(data.begin() + start_idx, data.begin() + end_idx, size_t(1), + std::multiplies()); +} + } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/normalization.cpp b/transformer_engine/jax/csrc/extensions/normalization.cpp index fb40400e62..9bd9951916 100644 --- a/transformer_engine/jax/csrc/extensions/normalization.cpp +++ b/transformer_engine/jax/csrc/extensions/normalization.cpp @@ -91,6 +91,200 @@ void LayerNormForwardImpl(size_t batch_size, size_t hidden_size, size_t workspac } } +Error_Type LayerNormForwardImplFFI(cudaStream_t stream, Buffer_Type *x_buf, Buffer_Type *gamma_buf, + Buffer_Type *beta_buf, Buffer_Type *amax_buf, + Buffer_Type *scale_buf, Buffer_Type *scale_inv_buf, + Result_Type *output_buf, Result_Type *mu_buf, + Result_Type *rsigma_buf, Result_Type *amax_out_buf, + Result_Type *wkspace_buf, Result_Type *barrier_buf, + bool zero_centered_gamma, double eps_, int64_t sm_margin_, + bool is_layer_norm, bool is_fp8) { + auto in_dtype = convert_ffi_datatype_to_te_dtype((*x_buf).element_type()); + auto w_dtype = convert_ffi_datatype_to_te_dtype((*gamma_buf).element_type()); + auto wkspace_dtype = convert_ffi_datatype_to_te_dtype((*wkspace_buf)->element_type()); + auto barrier_dtype = convert_ffi_datatype_to_te_dtype((*barrier_buf)->element_type()); + + auto *input = x_buf->untyped_data(); + auto *weight = gamma_buf->untyped_data(); + auto *output = (*output_buf)->untyped_data(); + auto *rsigma = (*rsigma_buf)->untyped_data(); + auto *workspace = (*wkspace_buf)->untyped_data(); + auto *barrier = (*barrier_buf)->untyped_data(); + + void *bias = nullptr; + void *mu = nullptr; + if (is_layer_norm) { + bias = beta_buf->untyped_data(); + mu = (*mu_buf)->untyped_data(); + } + + float *amax = nullptr; + float *scale = nullptr; + float *scale_inv = nullptr; + void *amax_out = nullptr; + auto out_dtype = in_dtype; + if (is_fp8) { + amax = reinterpret_cast(amax_buf->untyped_data()); + scale = reinterpret_cast(scale_buf->untyped_data()); + scale_inv = reinterpret_cast(scale_inv_buf->untyped_data()); + amax_out = (*amax_out_buf)->untyped_data(); + NVTE_CHECK(amax_out == amax, "amax not bound to amax_out in TE/JAX LayerNormForward primitive"); + out_dtype = DType::kFloat8E4M3; + } + + auto x_size = product(x_buf->dimensions()); + auto gamma_size = product(gamma_buf->dimensions()); + auto wkspace_size = product((*wkspace_buf)->dimensions()); + auto barrier_size = product((*barrier_buf)->dimensions()); + auto hidden_size = gamma_size; + auto batch_size = x_size / gamma_size; + + float eps = static_cast(eps_); + int sm_margin = static_cast(sm_margin_); + + LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma, + eps, input, in_dtype, weight, w_dtype, bias, output, 
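The ffi.h additions above back-fill C++20's std::source_location so that attribute-lookup failures can report the handler that made the call. A small usage sketch (hypothetical caller) showing how the defaulted arguments capture the call site:

#include <cstdio>

// Because the defaulted arguments are evaluated at each call site, current() records the
// function that asked for the attribute, so the error raised inside ffi.h names the right
// file, line and function.
void report_call_site(const char *what,
                      const source_location &loc = source_location::current()) {
  std::printf("%s requested from %s:%d (%s)\n", what, loc.file_name(), loc.line(),
              loc.function_name());
}

void some_ffi_handler() {
  report_call_site("attn_heads");  // prints this file, this line, "some_ffi_handler"
}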
out_dtype, workspace, + wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv, + sm_margin, stream); + return ffi_with_cuda_error_check(); +} + +Error_Type LayerNormForwardFP8FFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type gamma_buf, + Buffer_Type beta_buf, Buffer_Type amax_buf, Buffer_Type scale_buf, + Buffer_Type scale_inv_buf, Result_Type output_buf, + Result_Type mu_buf, Result_Type rsigma_buf, + Result_Type amax_out_buf, Result_Type wkspace_buf, + Result_Type barrier_buf, bool zero_centered_gamma, double eps_, + int64_t sm_margin_) { + return LayerNormForwardImplFFI(stream, &x_buf, &gamma_buf, &beta_buf, &amax_buf, &scale_buf, + &scale_inv_buf, &output_buf, &mu_buf, &rsigma_buf, &amax_out_buf, + &wkspace_buf, &barrier_buf, zero_centered_gamma, eps_, sm_margin_, + true, // is_layer_norm + true // is_fp8 + ); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(LayerNormForwardFP8Handler, LayerNormForwardFP8FFI, + FFI::Bind() + .Ctx() // stream + .Arg() // x + .Arg() // gamma + .Arg() // beta + .Arg() // amax + .Arg() // scale + .Arg() // scale_inv + .Ret() // output + .Ret() // mu + .Ret() // rsigma + .Ret() // amax_out + .Ret() // wkspace + .Ret() // barrier + .Attr("zero_centered_gamma") + .Attr("eps") + .Attr("sm_margin"), + FFI_CudaGraph_Traits); + +Error_Type LayerNormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type gamma_buf, + Buffer_Type beta_buf, Result_Type output_buf, Result_Type mu_buf, + Result_Type rsigma_buf, Result_Type wkspace_buf, + Result_Type barrier_buf, bool zero_centered_gamma, double eps_, + int64_t sm_margin_) { + return LayerNormForwardImplFFI(stream, &x_buf, &gamma_buf, &beta_buf, + nullptr, // amax_buf + nullptr, // scale_buf, + nullptr, // scale_inv_buf, + &output_buf, &mu_buf, &rsigma_buf, + nullptr, // amax_out_buf, + &wkspace_buf, &barrier_buf, zero_centered_gamma, eps_, sm_margin_, + true, // is_layer_norm + false // is_fp8 + ); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(LayerNormForwardHandler, LayerNormForwardFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // x + .Arg() // gamma + .Arg() // beta + .Ret() // output + .Ret() // mu + .Ret() // rsigma + .Ret() // wkspace + .Ret() // barrier + .Attr("zero_centered_gamma") + .Attr("eps") + .Attr("sm_margin"), + FFI_CudaGraph_Traits); + +Error_Type RMSNormForwardFP8FFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type gamma_buf, + Buffer_Type amax_buf, Buffer_Type scale_buf, + Buffer_Type scale_inv_buf, Result_Type output_buf, + Result_Type rsigma_buf, Result_Type amax_out_buf, + Result_Type wkspace_buf, Result_Type barrier_buf, + bool zero_centered_gamma, double eps_, int64_t sm_margin_) { + return LayerNormForwardImplFFI(stream, &x_buf, &gamma_buf, + nullptr, // beta_buf, + &amax_buf, &scale_buf, &scale_inv_buf, &output_buf, + nullptr, // mu_buf, + &rsigma_buf, &amax_out_buf, &wkspace_buf, &barrier_buf, + zero_centered_gamma, eps_, sm_margin_, + false, // is_layer_norm + true // is_fp8 + ); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(RMSNormForwardFP8Handler, RMSNormForwardFP8FFI, + FFI::Bind() + .Ctx() // stream + .Arg() // x + .Arg() // gamma + .Arg() // amax + .Arg() // scale + .Arg() // scale_inv + .Ret() // output + .Ret() // rsigma + .Ret() // amax_out + .Ret() // wkspace + .Ret() // barrier + .Attr("zero_centered_gamma") + .Attr("eps") + .Attr("sm_margin"), + FFI_CudaGraph_Traits); + +Error_Type RMSNormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type gamma_buf, + Result_Type output_buf, Result_Type rsigma_buf, + Result_Type wkspace_buf, Result_Type barrier_buf, + bool 
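All of the normalization forward entry points (LayerNorm and RMSNorm, FP8 and non-FP8) funnel into one implementation that derives the problem size from the gamma and x buffers. A sketch of that derivation, assuming plain dimension vectors:

#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

struct NormProblemSize {
  size_t batch_size;
  size_t hidden_size;
};

// hidden_size is the element count of gamma; batch_size is everything else in x.
NormProblemSize norm_problem_size(const std::vector<size_t> &x_dims,
                                  const std::vector<size_t> &gamma_dims) {
  auto count = [](const std::vector<size_t> &d) {
    return std::accumulate(d.begin(), d.end(), size_t{1}, std::multiplies<size_t>());
  };
  const size_t hidden_size = count(gamma_dims);
  const size_t batch_size = count(x_dims) / hidden_size;  // x is (batch..., hidden)
  return {batch_size, hidden_size};
}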
zero_centered_gamma, double eps_, int64_t sm_margin_) { + return LayerNormForwardImplFFI(stream, &x_buf, &gamma_buf, + nullptr, // beta_buf, + nullptr, // amax_buf, + nullptr, // scale_buf, + nullptr, // scale_inv_buf, + &output_buf, + nullptr, // mu_buf, + &rsigma_buf, + nullptr, // amax_out_buf, + &wkspace_buf, &barrier_buf, zero_centered_gamma, eps_, sm_margin_, + false, // is_layer_norm + false // is_fp8 + ); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(RMSNormForwardHandler, RMSNormForwardFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // x + .Arg() // gamma + .Ret() // output + .Ret() // rsigma + .Ret() // wkspace + .Ret() // barrier + .Attr("zero_centered_gamma") + .Attr("eps") + .Attr("sm_margin"), + FFI_CudaGraph_Traits); + pybind11::tuple GetLayerNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType w_dtype, bool is_layer_norm, bool zero_centered_gamma, @@ -199,6 +393,140 @@ void LayerNormBackwardImpl(size_t batch_size, size_t hidden_size, size_t wkspace } } +Error_Type LayerNormBackwardImplFFI(cudaStream_t stream, Buffer_Type *dz_buf, Buffer_Type *x_buf, + Buffer_Type *mu_buf, Buffer_Type *rsigma_buf, + Buffer_Type *gamma_buf, Result_Type *xgrad_buf, + Result_Type *wgrad_buf, Result_Type *dbeta_buf, + Result_Type *wkspace_buf, Result_Type *barrier_buf, + Result_Type *dgamma_part_buf, Result_Type *dbeta_part_buf, + bool zero_centered_gamma, double eps_, int64_t sm_margin_, + bool is_layer_norm) { + auto in_dtype = convert_ffi_datatype_to_te_dtype(x_buf->element_type()); + auto w_dtype = convert_ffi_datatype_to_te_dtype(gamma_buf->element_type()); + auto wkspace_dtype = convert_ffi_datatype_to_te_dtype((*wkspace_buf)->element_type()); + auto barrier_dtype = convert_ffi_datatype_to_te_dtype((*barrier_buf)->element_type()); + auto dgamma_part_dtype = convert_ffi_datatype_to_te_dtype((*dgamma_part_buf)->element_type()); + + auto *ograd = dz_buf->untyped_data(); + auto *rsigma = rsigma_buf->untyped_data(); + auto *input = x_buf->untyped_data(); + auto *weight = gamma_buf->untyped_data(); + auto *xgrad = (*xgrad_buf)->untyped_data(); + auto *wgrad = (*wgrad_buf)->untyped_data(); + auto *workspace = (*wkspace_buf)->untyped_data(); + auto *barrier = (*barrier_buf)->untyped_data(); + auto *dgamma_part = (*dgamma_part_buf)->untyped_data(); + + void *mu = nullptr; + void *dbeta = nullptr; + void *dbeta_part = nullptr; + auto dbeta_part_dtype = DType::kByte; + if (is_layer_norm) { + mu = (*mu_buf).untyped_data(); + dbeta = (*dbeta_buf)->untyped_data(); + dbeta_part = (*dbeta_part_buf)->untyped_data(); + dbeta_part_dtype = convert_ffi_datatype_to_te_dtype((*dbeta_part_buf)->element_type()); + } + + auto x_size = product(x_buf->dimensions()); + auto gamma_size = product(gamma_buf->dimensions()); + auto wkspace_size = product((*wkspace_buf)->dimensions()); + auto barrier_size = product((*barrier_buf)->dimensions()); + auto hidden_size = gamma_size; + auto batch_size = x_size / gamma_size; + + Shape dgamma_part_shape; + auto dgamma_part_dims = (*dgamma_part_buf)->dimensions(); + std::vector dgamma_parts_dims_vector(dgamma_part_dims.begin(), dgamma_part_dims.end()); + dgamma_part_shape.from_vector(dgamma_parts_dims_vector); + + Shape dbeta_part_shape; + if (is_layer_norm) { + auto dbeta_part_dims = (*dbeta_part_buf)->dimensions(); + std::vector dbeta_parts_dims_vector(dbeta_part_dims.begin(), dbeta_part_dims.end()); + dbeta_part_shape.from_vector(dbeta_parts_dims_vector); + } else { + dbeta_part_shape.from_vector({0, 0}); + } + + float eps = static_cast(eps_); + int 
sm_margin = static_cast(sm_margin_); + + LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_shape, + dbeta_part_shape, zero_centered_gamma, eps, input, in_dtype, weight, + w_dtype, ograd, workspace, wkspace_dtype, barrier, barrier_dtype, mu, + rsigma, xgrad, wgrad, dbeta, dgamma_part, dgamma_part_dtype, dbeta_part, + dbeta_part_dtype, sm_margin, stream); + return ffi_with_cuda_error_check(); +} + +Error_Type LayerNormBackwardFFI(cudaStream_t stream, Buffer_Type dz_buf, Buffer_Type x_buf, + Buffer_Type mu_buf, Buffer_Type rsigma_buf, Buffer_Type gamma_buf, + Result_Type xgrad_buf, Result_Type wgrad_buf, Result_Type dbeta_buf, + Result_Type wkspace_buf, Result_Type barrier_buf, + Result_Type dgamma_part_buf, Result_Type dbeta_part_buf, + bool zero_centered_gamma, double eps_, int64_t sm_margin_) { + return LayerNormBackwardImplFFI(stream, &dz_buf, &x_buf, &mu_buf, &rsigma_buf, &gamma_buf, + &xgrad_buf, &wgrad_buf, &dbeta_buf, &wkspace_buf, &barrier_buf, + &dgamma_part_buf, &dbeta_part_buf, zero_centered_gamma, eps_, + sm_margin_, + true // is_layer_norm + ); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(LayerNormBackwardHandler, LayerNormBackwardFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // dz + .Arg() // x + .Arg() // mu + .Arg() // rsigma + .Arg() // gamma + .Ret() // xgrad + .Ret() // wgrad + .Ret() // dbeta + .Ret() // wkspace + .Ret() // barrier + .Ret() // dgamma_part + .Ret() // dbeta_part + .Attr("zero_centered_gamma") + .Attr("eps") + .Attr("sm_margin"), + FFI_CudaGraph_Traits); + +Error_Type RMSNormBackwardFFI(cudaStream_t stream, Buffer_Type dz_buf, Buffer_Type x_buf, + Buffer_Type rsigma_buf, Buffer_Type gamma_buf, Result_Type xgrad_buf, + Result_Type wgrad_buf, Result_Type wkspace_buf, + Result_Type barrier_buf, Result_Type dgamma_part_buf, + bool zero_centered_gamma, double eps_, int64_t sm_margin_) { + return LayerNormBackwardImplFFI(stream, &dz_buf, &x_buf, + nullptr, // mu_buf + &rsigma_buf, &gamma_buf, &xgrad_buf, &wgrad_buf, + nullptr, // dbeta_buf, + &wkspace_buf, &barrier_buf, &dgamma_part_buf, + nullptr, // dbeta_part_buf, + zero_centered_gamma, eps_, sm_margin_, + false // is_layer_norm + ); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(RMSNormBackwardHandler, RMSNormBackwardFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // dz + .Arg() // x + .Arg() // rsigma + .Arg() // gamma + .Ret() // xgrad + .Ret() // wgrad + .Ret() // wkspace + .Ret() // barrier + .Ret() // dgamma_part + .Attr("zero_centered_gamma") + .Attr("eps") + .Attr("sm_margin"), + FFI_CudaGraph_Traits); + void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { auto *input = buffers[0]; diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 14f449a76b..9b5c156e5d 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -52,9 +52,55 @@ pybind11::dict Registrations() { dict["te_fused_attn_forward"] = EncapsulateFunction(FusedAttnForward); dict["te_fused_attn_backward"] = EncapsulateFunction(FusedAttnBackward); + // Transpose + dict["te_transpose_ffi"] = EncapsulateFFI(TransposeHandler); dict["te_cast_transpose_ffi"] = EncapsulateFFI(CastTransposeHandler); + dict["te_dbias_cast_transpose_ffi"] = EncapsulateFFI(DBiasCastTransposeHandler); + + // Activation dict["te_act_lu_ffi"] = EncapsulateFFI(ActLuHandler); + dict["te_act_lu_fp8_ffi"] = EncapsulateFFI(ActLuFP8Handler); dict["te_dact_lu_ffi"] = 
EncapsulateFFI(DActLuHandler); + dict["te_dact_lu_dbias_cast_transpose_ffi"] = + EncapsulateFunction(DActLuDBiasCastTransposeHandler); + dict["te_dgated_act_lu_cast_transpose_ffi"] = + EncapsulateFunction(DGatedActLuCastTransposeHandler); + + // Quantization + dict["te_quantize_ffi"] = EncapsulateFFI(QuantizeHandler); + dict["te_dequantize_ffi"] = EncapsulateFFI(DequantizeHandler); + + // Softmax + dict["te_scaled_softmax_forward_ffi"] = EncapsulateFunction(ScaledSoftmaxForwardHandler); + dict["te_scaled_softmax_backward_ffi"] = EncapsulateFunction(ScaledSoftmaxBackwardHandler); + dict["te_scaled_masked_softmax_forward_ffi"] = + EncapsulateFunction(ScaledMaskedSoftmaxForwardHandler); + dict["te_scaled_masked_softmax_backward_ffi"] = + EncapsulateFunction(ScaledMaskedSoftmaxBackwardHandler); + dict["te_scaled_upper_triang_masked_softmax_forward_ffi"] = + EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxForwardHandler); + dict["te_scaled_upper_triang_masked_softmax_backward_ffi"] = + EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxBackwardHandler); + + // Normalization + dict["te_layernorm_forward_ffi"] = EncapsulateFFI(LayerNormForwardHandler); + dict["te_layernorm_forward_fp8_ffi"] = EncapsulateFFI(LayerNormForwardFP8Handler); + dict["te_layernorm_backward_ffi"] = EncapsulateFFI(LayerNormBackwardHandler); + dict["te_rmsnorm_forward_ffi"] = EncapsulateFunction(RMSNormForwardHandler); + dict["te_rmsnorm_forward_fp8_ffi"] = EncapsulateFunction(RMSNormForwardFP8Handler); + dict["te_rmsnorm_backward_ffi"] = EncapsulateFunction(RMSNormBackwardHandler); + + // Attention + pybind11::dict fused_attn_forward_ffi; + fused_attn_forward_ffi["prepare"] = EncapsulateFFI(CudnnHandleInitHandler); + fused_attn_forward_ffi["execute"] = EncapsulateFFI(FusedAttnForwardHandler); + dict["te_fused_attn_forward_ffi"] = fused_attn_forward_ffi; + + pybind11::dict fused_attn_backward_ffi; + fused_attn_backward_ffi["prepare"] = EncapsulateFFI(CudnnHandleInitHandler); + fused_attn_backward_ffi["execute"] = EncapsulateFFI(FusedAttnBackwardHandler); + dict["te_fused_attn_backward_ffi"] = fused_attn_backward_ffi; + return dict; } diff --git a/transformer_engine/jax/csrc/extensions/quantization.cpp b/transformer_engine/jax/csrc/extensions/quantization.cpp index ba376c6238..d08368657e 100644 --- a/transformer_engine/jax/csrc/extensions/quantization.cpp +++ b/transformer_engine/jax/csrc/extensions/quantization.cpp @@ -6,6 +6,7 @@ #include "extensions.h" #include "transformer_engine/cast.h" +#include "xla/ffi/api/c_api.h" namespace transformer_engine { namespace jax { @@ -27,6 +28,41 @@ void Quantize(cudaStream_t stream, void **buffers, const char *opaque, size_t op nvte_fp8_quantize(input_tensor.data(), output_tensor.data(), stream); } +Error_Type QuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf, + Buffer_Type scale_buf, Buffer_Type scale_inv_buf, Result_Type output_buf, + Result_Type amax_out_buf) { + auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); + auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); + + auto *input = input_buf.untyped_data(); + auto *amax = reinterpret_cast(amax_buf.untyped_data()); + auto *scale = reinterpret_cast(scale_buf.untyped_data()); + auto *scale_inv = reinterpret_cast(scale_inv_buf.untyped_data()); + + auto *output = output_buf->untyped_data(); + auto *amax_out = reinterpret_cast(amax_out_buf->untyped_data()); + NVTE_CHECK(amax == amax_out, "amax not bound to amax_out in TE/JAX Quantize primitive."); + + auto 
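The quantize and dequantize FFI ports here bind amax, scale and scale_inv to the FP8 TensorWrapper and leave the actual conversion to the library kernels. A host-side sketch, illustration only, of the scaling relationship those tensors encode (448 is the largest finite FP8 E4M3 value; FP8 mantissa rounding is omitted):

#include <algorithm>

struct Fp8Roundtrip {
  float quantized;
  float dequantized;
};

// q = saturate(x * scale) in the FP8 range, and x is recovered as q * scale_inv on dequantize.
Fp8Roundtrip fp8_e4m3_roundtrip(float x, float scale) {
  constexpr float kE4M3Max = 448.0f;
  const float scaled = std::clamp(x * scale, -kE4M3Max, kE4M3Max);
  const float scale_inv = 1.0f / scale;
  return {scaled, scaled * scale_inv};
}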
input_dims = input_buf.dimensions(); + std::vector shape(input_dims.begin(), input_dims.end()); + auto input_tensor = TensorWrapper(input, shape, in_dtype); + auto output_tensor = TensorWrapper(output, shape, out_dtype, amax_out, scale, scale_inv); + + nvte_fp8_quantize(input_tensor.data(), output_tensor.data(), stream); + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(QuantizeHandler, QuantizeFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // input + .Arg() // amax + .Arg() // scale + .Arg() // scale_inv + .Ret() // output + .Ret(), // amax_out + FFI_CudaGraph_Traits); + void Dequantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { auto *input = buffers[0]; auto *amax = reinterpret_cast(buffers[1]); @@ -38,11 +74,41 @@ void Dequantize(cudaStream_t stream, void **buffers, const char *opaque, size_t auto shape = desc.shape.to_vector(); auto input_tensor = TensorWrapper(input, shape, desc.in_dtype, amax, scale, scale_inv); - auto output_tensor = TensorWrapper(output, shape, desc.out_dtype); nvte_fp8_dequantize(input_tensor.data(), output_tensor.data(), stream); } +Error_Type DequantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf, + Buffer_Type scale_buf, Buffer_Type scale_inv_buf, Result_Type output_buf) { + auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); + auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); + + auto *input = input_buf.untyped_data(); + auto *amax = reinterpret_cast(amax_buf.untyped_data()); + auto *scale = reinterpret_cast(scale_buf.untyped_data()); + auto *scale_inv = reinterpret_cast(scale_inv_buf.untyped_data()); + + auto *output = output_buf->untyped_data(); + + auto input_dims = input_buf.dimensions(); + std::vector shape(input_dims.begin(), input_dims.end()); + auto input_tensor = TensorWrapper(input, shape, in_dtype, amax, scale, scale_inv); + auto output_tensor = TensorWrapper(output, shape, out_dtype); + + nvte_fp8_dequantize(input_tensor.data(), output_tensor.data(), stream); + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(DequantizeHandler, DequantizeFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // input + .Arg() // amax + .Arg() // scale + .Arg() // scale_inv + .Ret(), // output + FFI_CudaGraph_Traits); + } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/softmax.cpp b/transformer_engine/jax/csrc/extensions/softmax.cpp index 3af32d1d84..f54ebefcb0 100644 --- a/transformer_engine/jax/csrc/extensions/softmax.cpp +++ b/transformer_engine/jax/csrc/extensions/softmax.cpp @@ -7,6 +7,7 @@ #include "transformer_engine/softmax.h" #include "extensions.h" +#include "xla/ffi/api/c_api.h" namespace transformer_engine { namespace jax { @@ -108,5 +109,146 @@ void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, dgrad_tensor.data(), desc.scale_factor, stream); } +#define SOFTMAX_COMMON_BLOCK(tensor_buf) \ + auto dtype = convert_ffi_datatype_to_te_dtype((tensor_buf).element_type()); \ + auto tensor_dims = (tensor_buf).dimensions(); \ + auto tensor_ranks = tensor_dims.size(); \ + auto batch_size = product(tensor_dims, 0, tensor_ranks - 3); \ + auto head_dim = product(tensor_dims, tensor_ranks - 3, tensor_ranks - 2); \ + auto q_seqlen = product(tensor_dims, tensor_ranks - 2, tensor_ranks - 1); \ + auto k_seqlen = product(tensor_dims, tensor_ranks - 1, tensor_ranks); \ + float scale_factor = static_cast(scale_factor_); + +#define 
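SOFTMAX_COMMON_BLOCK above splits the logits buffer into batch, heads, q_seqlen and k_seqlen from its trailing axes. A sketch of that split over a plain dimension vector:

#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

struct SoftmaxDims {
  size_t batch_size, head_dim, q_seqlen, k_seqlen;
};

// The last three axes are heads, q_seqlen and k_seqlen; every leading axis folds into batch.
SoftmaxDims split_softmax_dims(const std::vector<size_t> &dims) {
  const size_t rank = dims.size();  // expected rank >= 4 for (batch..., heads, q, k)
  const size_t batch = std::accumulate(dims.begin(), dims.begin() + (rank - 3), size_t{1},
                                       std::multiplies<size_t>());
  return {batch, dims[rank - 3], dims[rank - 2], dims[rank - 1]};
}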
SOFTMAX_FORWARD_COMMON_BLOCK \ + auto *input = input_buf.untyped_data(); \ + auto *output = output_buf->untyped_data(); \ + auto input_tensor = TensorWrapper(input, shape, dtype); \ + auto output_tensor = TensorWrapper(output, shape, dtype); + +Error_Type ScaledSoftmaxForwardFFI(cudaStream_t stream, Buffer_Type input_buf, + Result_Type output_buf, double scale_factor_) { + SOFTMAX_COMMON_BLOCK(input_buf); + auto shape = std::vector{batch_size, head_dim, q_seqlen, k_seqlen}; + SOFTMAX_FORWARD_COMMON_BLOCK; + nvte_scaled_softmax_forward(input_tensor.data(), output_tensor.data(), scale_factor, stream); + return ffi_with_cuda_error_check(); +} + +Error_Type ScaledMaskedSoftmaxForwardFFI(cudaStream_t stream, Buffer_Type input_buf, + Buffer_Type mask_buf, Result_Type output_buf, + double scale_factor_) { + SOFTMAX_COMMON_BLOCK(input_buf); + + // Mask would be casted to uint8_t + auto *mask = mask_buf.untyped_data(); + auto mask_dims = mask_buf.dimensions(); + auto padding_size = product(mask_dims, mask_dims.size() - 3); + auto mask_shape = std::vector{padding_size, 1, q_seqlen, k_seqlen}; + auto mask_tensor = TensorWrapper(mask, mask_shape, DType::kByte); + + auto shape = std::vector{batch_size, head_dim, q_seqlen, k_seqlen}; + SOFTMAX_FORWARD_COMMON_BLOCK; + nvte_scaled_masked_softmax_forward(input_tensor.data(), mask_tensor.data(), output_tensor.data(), + scale_factor, stream); + return ffi_with_cuda_error_check(); +} + +Error_Type ScaledUpperTriangMaskedSoftmaxForwardFFI(cudaStream_t stream, Buffer_Type input_buf, + Result_Type output_buf, double scale_factor_) { + SOFTMAX_COMMON_BLOCK(input_buf); + auto shape = std::vector{batch_size * head_dim, q_seqlen, k_seqlen}; + SOFTMAX_FORWARD_COMMON_BLOCK; + nvte_scaled_upper_triang_masked_softmax_forward(input_tensor.data(), output_tensor.data(), + scale_factor, stream); + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(ScaledSoftmaxForwardHandler, ScaledSoftmaxForwardFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // input + .Ret() // output + .Attr("scale_factor"), + FFI_CudaGraph_Traits); + +XLA_FFI_DEFINE_HANDLER_SYMBOL(ScaledMaskedSoftmaxForwardHandler, ScaledMaskedSoftmaxForwardFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // input + .Arg() // mask + .Ret() // output + .Attr("scale_factor"), + FFI_CudaGraph_Traits); + +XLA_FFI_DEFINE_HANDLER_SYMBOL(ScaledUpperTriangMaskedSoftmaxForwardHandler, + ScaledUpperTriangMaskedSoftmaxForwardFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // input + .Ret() // output + .Attr("scale_factor"), + FFI_CudaGraph_Traits); + +#define SOFTMAX_BACKWARD_COMMON_BLOCK \ + auto *grad_output = grad_output_buf.untyped_data(); \ + auto *softmax_output = softmax_output_buf.untyped_data(); \ + auto *dgrad = dgrad_buf->untyped_data(); \ + auto grad_output_tensor = TensorWrapper(grad_output, shape, dtype); \ + auto softmax_output_tensor = TensorWrapper(softmax_output, shape, dtype); \ + auto dgrad_tensor = TensorWrapper(dgrad, shape, dtype); + +Error_Type ScaledSoftmaxBackwardFFI(cudaStream_t stream, Buffer_Type grad_output_buf, + Buffer_Type softmax_output_buf, Result_Type dgrad_buf, + double scale_factor_) { + SOFTMAX_COMMON_BLOCK(grad_output_buf); + auto shape = std::vector{batch_size, head_dim, q_seqlen, k_seqlen}; + SOFTMAX_BACKWARD_COMMON_BLOCK; + nvte_scaled_softmax_backward(grad_output_tensor.data(), softmax_output_tensor.data(), + dgrad_tensor.data(), scale_factor, stream); + return ffi_with_cuda_error_check(); +} + +Error_Type ScaledUpperTriangMaskedSoftmaxBackwardFFI(cudaStream_t stream, + 
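SOFTMAX_COMMON_BLOCK collapses all leading dimensions into a batch size and keeps the last three dimensions as heads and the two sequence lengths; the upper-triangular variants then fold batch and heads into a single leading dimension. A small Python sketch of that bookkeeping (math.prod mirrors the product helper used in the macro):

```python
import math

def softmax_dims(tensor_dims):
    """Mirror of SOFTMAX_COMMON_BLOCK: (..., heads, q_seqlen, k_seqlen) -> 4 sizes."""
    r = len(tensor_dims)
    batch_size = math.prod(tensor_dims[: r - 3])  # empty product is 1
    head_dim, q_seqlen, k_seqlen = tensor_dims[r - 3:]
    return batch_size, head_dim, q_seqlen, k_seqlen

b, h, sq, sk = softmax_dims((2, 3, 16, 128, 128))
print(b, h, sq, sk)     # 6 16 128 128   -> shape {b, h, sq, sk} for the plain kernels
print((b * h, sq, sk))  # (96, 128, 128) -> shape for the upper-triangular kernels
```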
Buffer_Type grad_output_buf, + Buffer_Type softmax_output_buf, + Result_Type dgrad_buf, double scale_factor_) { + SOFTMAX_COMMON_BLOCK(grad_output_buf); + auto shape = std::vector{batch_size * head_dim, q_seqlen, k_seqlen}; + SOFTMAX_BACKWARD_COMMON_BLOCK; + nvte_scaled_upper_triang_masked_softmax_backward(grad_output_tensor.data(), + softmax_output_tensor.data(), + dgrad_tensor.data(), scale_factor, stream); + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(ScaledSoftmaxBackwardHandler, ScaledSoftmaxBackwardFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // grad_output + .Arg() // softmax_output + .Ret() // dgrad + .Attr("scale_factor"), + FFI_CudaGraph_Traits); + +// The backward of ScaledMaskedSoftmax is equivalent to ScaledSoftmax +XLA_FFI_DEFINE_HANDLER_SYMBOL(ScaledMaskedSoftmaxBackwardHandler, ScaledSoftmaxBackwardFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // grad_output + .Arg() // softmax_output + .Ret() // dgrad + .Attr("scale_factor"), + FFI_CudaGraph_Traits); + +XLA_FFI_DEFINE_HANDLER_SYMBOL(ScaledUpperTriangMaskedSoftmaxBackwardHandler, + ScaledUpperTriangMaskedSoftmaxBackwardFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // grad_output + .Arg() // softmax_output + .Ret() // dgrad + .Attr("scale_factor"), + FFI_CudaGraph_Traits); + } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/transpose.cpp b/transformer_engine/jax/csrc/extensions/transpose.cpp index 7a2e31312a..8480081a68 100644 --- a/transformer_engine/jax/csrc/extensions/transpose.cpp +++ b/transformer_engine/jax/csrc/extensions/transpose.cpp @@ -36,6 +36,37 @@ void Transpose(cudaStream_t stream, void **buffers, const char *opaque, size_t o TransposeImpl(input, rows, cols, dtype, stream, output); } +Error_Type TransposeFFI(cudaStream_t stream, Buffer_Type input_buf, Result_Type output_buf, + int64_t transpose_axis) { + auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); + auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); + + void *input = input_buf.untyped_data(); + void *output = output_buf->untyped_data(); + + auto input_dims = input_buf.dimensions(); + if (transpose_axis < 0) transpose_axis += input_dims.size(); + auto m = product(input_dims, 0, transpose_axis); + auto n = product(input_dims, transpose_axis, input_dims.size()); + + auto input_shape = std::vector{m, n}; + auto output_shape = std::vector{n, m}; + + auto input_tensor = TensorWrapper(input, input_shape, in_dtype); + auto output_tensor = TensorWrapper(output, output_shape, out_dtype); + + nvte_transpose(input_tensor.data(), output_tensor.data(), stream); + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(TransposeHandler, TransposeFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // input + .Ret() // output + .Attr("transpose_axis"), + FFI_CudaGraph_Traits); + void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { auto *input = buffers[0]; float *amax = reinterpret_cast(buffers[1]); @@ -69,20 +100,20 @@ void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size Error_Type CastTransposeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf, Buffer_Type scale_buf, Buffer_Type scale_inv_buf, - Result_Type input_cast_buf, Result_Type input_cast_trans_buf, + Result_Type output_buf, Result_Type output_trans_buf, Result_Type amax_out_buf, int64_t transpose_axis) { auto in_dtype = 
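TransposeFFI and the cast-transpose variants treat transpose_axis as a split point: every dimension before it is flattened into M, everything from it onward into N, and the kernel writes the (N, M) transpose. A NumPy sketch of the same shape logic, using a hypothetical helper name:

```python
import numpy as np

def flatten_and_transpose(x, transpose_axis):
    """Sketch of the transpose_axis convention: flatten to (M, N), emit (N, M)."""
    if transpose_axis < 0:
        transpose_axis += x.ndim
    m = int(np.prod(x.shape[:transpose_axis], dtype=np.int64))
    n = int(np.prod(x.shape[transpose_axis:], dtype=np.int64))
    return np.ascontiguousarray(x.reshape(m, n).T)

x = np.arange(24, dtype=np.float32).reshape(2, 3, 4)
print(flatten_and_transpose(x, -1).shape)  # (4, 6)  -- last dim vs. the rest
print(flatten_and_transpose(x, 1).shape)   # (12, 2)
```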
convert_ffi_datatype_to_te_dtype(input_buf.element_type()); - auto out_dtype = convert_ffi_datatype_to_te_dtype(input_cast_buf->element_type()); + auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); auto *input = input_buf.untyped_data(); float *amax = reinterpret_cast(amax_buf.untyped_data()); float *scale = reinterpret_cast(scale_buf.untyped_data()); float *scale_inv = reinterpret_cast(scale_inv_buf.untyped_data()); - auto *input_cast = input_cast_buf->untyped_data(); - auto *input_cast_trans = input_cast_trans_buf->untyped_data(); + auto *output = output_buf->untyped_data(); + auto *output_trans = output_trans_buf->untyped_data(); float *amax_out = reinterpret_cast(amax_out_buf->untyped_data()); - assert(amax == amax_out); + NVTE_CHECK(amax == amax_out, "amax not bound to amax_out in TE/JAX CastTranspose primitive."); if (!use_fp8(out_dtype)) { scale = nullptr; @@ -92,20 +123,18 @@ Error_Type CastTransposeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T auto input_dims = input_buf.dimensions(); if (transpose_axis < 0) transpose_axis += input_dims.size(); - auto m = std::accumulate(input_dims.begin(), input_dims.begin() + transpose_axis, 1, - std::multiplies<>()); - auto n = std::accumulate(input_dims.begin() + transpose_axis, input_dims.end(), 1, - std::multiplies<>()); + auto m = product(input_dims, 0, transpose_axis); + auto n = product(input_dims, transpose_axis, input_dims.size()); auto input_shape = std::vector{m, n}; - auto input_trans_shape = std::vector{n, m}; + auto output_shape = input_shape; + auto output_trans_shape = std::vector{n, m}; auto input_tensor = TensorWrapper(input, input_shape, in_dtype); - auto input_cast_tensor = - TensorWrapper(input_cast, input_shape, out_dtype, amax_out, scale, scale_inv); - auto input_cast_trans_tensor = - TensorWrapper(input_cast_trans, input_trans_shape, out_dtype, amax_out, scale, scale_inv); + auto output_tensor = TensorWrapper(output, output_shape, out_dtype, amax_out, scale, scale_inv); + auto output_trans_tensor = + TensorWrapper(output_trans, output_trans_shape, out_dtype, amax_out, scale, scale_inv); - nvte_cast_transpose(input_tensor.data(), input_cast_tensor.data(), input_cast_trans_tensor.data(), + nvte_cast_transpose(input_tensor.data(), output_tensor.data(), output_trans_tensor.data(), stream); return ffi_with_cuda_error_check(); } @@ -117,10 +146,11 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(CastTransposeHandler, CastTransposeFFI, .Arg() // amax .Arg() // scale .Arg() // scale_inv - .Ret() // input_cast - .Ret() // input_cast_trans + .Ret() // output + .Ret() // output_trans .Ret() // amax_out - .Attr("transpose_axis")); + .Attr("transpose_axis"), + FFI_CudaGraph_Traits); pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType out_dtype) { @@ -183,5 +213,70 @@ void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, dbias_tensor.data(), workspace.data(), stream); } +Error_Type DBiasCastTransposeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf, + Buffer_Type scale_buf, Buffer_Type scale_inv_buf, + Result_Type output_buf, Result_Type output_trans_buf, + Result_Type dbias_buf, Result_Type amax_out_buf, + Result_Type workspace_buf, int64_t transpose_axis) { + auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); + auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); + auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type()); + + 
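CastTransposeFFI and DBiasCastTransposeFFI produce the FP8 cast in both (M, N) and (N, M) layouts in one pass, sharing the same scale/amax metadata; the dbias variant additionally reduces the gradient over the flattened M dimension, keeping the result in the input dtype. A rough NumPy sketch of those outputs (float16 again standing in for FP8, reduction semantics assumed to be a plain column sum):

```python
import numpy as np

def dbias_cast_transpose(grad, scale, fp8_max=448.0):
    """Sketch: one pass yields cast grad (M, N), its transpose (N, M), and dbias (N,)."""
    amax = float(np.abs(grad).max())
    q = np.clip(grad * scale, -fp8_max, fp8_max).astype(np.float16)  # stand-in FP8 cast
    dbias = grad.sum(axis=0)                                         # reduction over M, input dtype
    return q, np.ascontiguousarray(q.T), dbias, amax

g = np.random.randn(12, 8).astype(np.float32)
out, out_t, dbias, amax = dbias_cast_transpose(g, scale=1.0)
print(out.shape, out_t.shape, dbias.shape)  # (12, 8) (8, 12) (8,)
```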
auto *input = input_buf.untyped_data(); + float *amax = reinterpret_cast(amax_buf.untyped_data()); + float *scale = reinterpret_cast(scale_buf.untyped_data()); + float *scale_inv = reinterpret_cast(scale_inv_buf.untyped_data()); + + auto *output = output_buf->untyped_data(); + auto *output_trans = output_trans_buf->untyped_data(); + auto *dbias = dbias_buf->untyped_data(); + float *amax_out = reinterpret_cast(amax_out_buf->untyped_data()); + void *workspace = workspace_buf->untyped_data(); + NVTE_CHECK(amax == amax_out, + "amax not bound to amax_out in TE/JAX DBiasCastTranspose primitive."); + if (!use_fp8(out_dtype)) { + scale = nullptr; + scale_inv = nullptr; + amax_out = nullptr; + } + + auto input_dims = input_buf.dimensions(); + auto workspace_dims = workspace_buf->dimensions(); + if (transpose_axis < 0) transpose_axis += input_dims.size(); + auto m = product(input_dims, 0, transpose_axis); + auto n = product(input_dims, transpose_axis, input_dims.size()); + auto input_shape = std::vector{m, n}; + auto output_shape = std::vector{m, n}; + auto output_trans_shape = std::vector{n, m}; + auto dbias_shape = std::vector{n}; + std::vector workspace_shape(workspace_dims.begin(), workspace_dims.end()); + + auto input_tensor = TensorWrapper(input, input_shape, in_dtype); + auto output_tensor = TensorWrapper(output, output_shape, out_dtype, amax_out, scale, scale_inv); + auto output_trans_tensor = + TensorWrapper(output_trans, output_trans_shape, out_dtype, amax_out, scale, scale_inv); + auto dbias_tensor = TensorWrapper(dbias, dbias_shape, in_dtype); + auto workspace_tensor = TensorWrapper(workspace, workspace_shape, workspace_dtype); + + nvte_cast_transpose_dbias(input_tensor.data(), output_tensor.data(), output_trans_tensor.data(), + dbias_tensor.data(), workspace_tensor.data(), stream); + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(DBiasCastTransposeHandler, DBiasCastTransposeFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // input + .Arg() // amax + .Arg() // scale + .Arg() // scale_inv + .Ret() // output + .Ret() // output_trans + .Ret() // dbias + .Ret() // amax_out + .Ret() // workspace + .Attr("transpose_axis"), + FFI_CudaGraph_Traits); + } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index b91584219f..cb71188221 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -262,6 +262,8 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me scale_factor: Optional[float] = None transpose_batch_sequence: bool = False window_size: Optional[Tuple[int, int]] = None + context_parallel_causal_load_balanced: bool = False + context_parallel_axis: str = "" @nn.compact def __call__( @@ -308,6 +310,8 @@ def __call__( dropout_probability=self.attention_dropout, is_training=not deterministic, window_size=self.window_size, + context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced, + context_parallel_axis=self.context_parallel_axis, ) elif self.qkv_layout == QKVLayout.BSHD_BS2HD: """kvpacked format, treat @@ -331,6 +335,8 @@ def __call__( dropout_probability=self.attention_dropout, is_training=not deterministic, window_size=self.window_size, + context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced, + context_parallel_axis=self.context_parallel_axis, ) elif self.qkv_layout == QKVLayout.BSHD_BSHD_BSHD: if self.transpose_batch_sequence: @@ -349,6 
+355,8 @@ def __call__( dropout_probability=self.attention_dropout, is_training=not deterministic, window_size=self.window_size, + context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced, + context_parallel_axis=self.context_parallel_axis, ) else: raise ValueError(f"Unsupported {self.qkv_layout=}.") @@ -463,6 +471,9 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods should be in (seqlen, batch, ...), otherwise (batch, seqlen, ...). window_size: Optional[Tuple[int, int]], default = None Sliding window size. The default value is no sliding window. + context_parallel_causal_load_balanced (bool): + Indicates the sequences are ordered for causal mask load balancing when running context parallelism. + context_parallel_axis (str): The name of the context parallel axis. Optimization parameters ----------------------- @@ -483,6 +494,8 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods scale_factor: Optional[float] = None transpose_batch_sequence: bool = True window_size: Optional[Tuple[int, int]] = None + context_parallel_causal_load_balanced: bool = False + context_parallel_axis: str = "" @nn.compact def __call__( @@ -614,6 +627,8 @@ def __call__( transpose_batch_sequence=self.transpose_batch_sequence, qkv_layout=qkv_layout, window_size=self.window_size, + context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced, + context_parallel_axis=self.context_parallel_axis, )(query, key, value, mask, bias, dropout_rng=dropout_rng, deterministic=deterministic) return x diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index 90504e4c14..bbf0b0f52b 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -516,7 +516,6 @@ def _fused_layernorm_fp8_mlp_bwd_rule( dactivation_lu_scale_inv, bwd_dtype, static_axis_boundary=-1, - transpose_axis_boundary=-2, activation_type=activation_type, ) ) diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py index a14a8384cf..f2da288be5 100644 --- a/transformer_engine/jax/sharding.py +++ b/transformer_engine/jax/sharding.py @@ -197,7 +197,7 @@ class MeshResource: The axis name in Mesh used to split the batch and weights along. If it is None, then full-sharded data parallelism is disabled. pp_resource : str, default = None - The axis name in Mesh used to split model layers. along. + The axis name in Mesh used to split model layers along. If it is None, then pipeline parallelism is disabled. 
cp_resource : str, default = None The axis name in Mesh used to split sequence (context) dimensions along diff --git a/transformer_engine/paddle/csrc/custom_ops.cu b/transformer_engine/paddle/csrc/custom_ops.cu index 904d979b8e..583cd0f47a 100644 --- a/transformer_engine/paddle/csrc/custom_ops.cu +++ b/transformer_engine/paddle/csrc/custom_ops.cu @@ -603,14 +603,14 @@ void UpdateRandomGenerator(phi::Place place, cudaStream_t stream, int rng_elts_p auto state_index = gen_cuda->GetStateIndex(); auto parameterSetter = [gen_cuda, state_index, - rng_elts_per_thread](phi::backends::gpu::CUDAKernelParams ¶ms) { + rng_elts_per_thread](phi::backends::gpu::gpuKernelParams ¶ms) { // ensure the generator use correct state index gen_cuda->SetStateIndex(state_index); auto seed_offset = gen_cuda->IncrementOffset(rng_elts_per_thread); params.As>(1) = seed_offset; }; - phi::backends::gpu::CUDAGraphNodeLauncher::cudaKernelCallback_t cudaKernelCallback = + phi::backends::gpu::CUDAGraphNodeLauncher::gpuKernelCallback_t cudaKernelCallback = [=](unsigned int id) { void *functionPtr = reinterpret_cast(&set_rng_state); cudaFunction_t cudaFunc; @@ -1016,14 +1016,14 @@ void te_fused_attn_fwd(const paddle::Tensor &Q, const paddle::Tensor &K, const p #if PADDLE_VERSION > 261 auto state_index = gen_cuda->GetStateIndex(); auto parameterSetter = [gen_cuda, state_index, - rng_elts_per_thread](phi::backends::gpu::CUDAKernelParams ¶ms) { + rng_elts_per_thread](phi::backends::gpu::gpuKernelParams ¶ms) { // ensure the generator use correct state index gen_cuda->SetStateIndex(state_index); auto seed_offset = gen_cuda->IncrementOffset(rng_elts_per_thread); params.As>(1) = seed_offset; }; - phi::backends::gpu::CUDAGraphNodeLauncher::cudaKernelCallback_t cudaKernelCallback = + phi::backends::gpu::CUDAGraphNodeLauncher::gpuKernelCallback_t cudaKernelCallback = [=](unsigned int id) { void *functionPtr = reinterpret_cast(&set_rng_state); cudaFunction_t cudaFunc; @@ -1383,7 +1383,7 @@ void amax_and_scale_update_inplace_legacy( const int *current_step_id_ptr = reinterpret_cast(GetOptionalDataPtr(current_step_id_tensor)); auto parameterSetter = [current_step_id_ptr, - fwd_update](phi::backends::gpu::CUDAKernelParams ¶ms) { + fwd_update](phi::backends::gpu::gpuKernelParams ¶ms) { if (fwd_update) { int current_step_id = *current_step_id_ptr; params.As(7) = (current_step_id == 0); @@ -1397,7 +1397,7 @@ void amax_and_scale_update_inplace_legacy( float *scale_ptr = scale.data(); float *scale_inv_ptr = scale_inv.data(); - phi::backends::gpu::CUDAGraphNodeLauncher::cudaKernelCallback_t cudaKernelCallback = + phi::backends::gpu::CUDAGraphNodeLauncher::gpuKernelCallback_t cudaKernelCallback = [=](unsigned int id) { void *functionPtr = reinterpret_cast(&UpdateFP8MetaKernel); cudaFunction_t cudaFunc; diff --git a/transformer_engine/paddle/csrc/extensions.cu b/transformer_engine/paddle/csrc/extensions.cpp similarity index 100% rename from transformer_engine/paddle/csrc/extensions.cu rename to transformer_engine/paddle/csrc/extensions.cpp diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index c4097333d3..781f9d42fd 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -82,6 +82,7 @@ def _load_library(): from transformer_engine.pytorch.distributed import checkpoint from transformer_engine.pytorch.distributed import CudaRNGStatesTracker from transformer_engine.pytorch.cpu_offload import get_cpu_offload_context +from transformer_engine.pytorch import ops from 
transformer_engine.pytorch import optimizers # Register custom op symbolic ONNX functions diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index be36b0375a..8159f20e90 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -1729,17 +1729,20 @@ def forward( fused_attn_qkv_dtype = None fused_attn_backend = None amax_per_step = None + qkv_dtype = q.dtype + # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype + is_input_fp8 = False + is_output_fp8 = fp8_meta is not None and fp8_meta["recipe"].fp8_mha if fp8: if use_fused_attention: fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) fused_attn_qkv_dtype = fp8_dtype_forward fused_attn_backend = FusedAttnBackend["FP8"] - if fp8_meta["recipe"].fp8_mha: - assert ( - isinstance(q, Float8Tensor) - and isinstance(k, Float8Tensor) - and isinstance(v, Float8Tensor) - ), "q/k/v must be Float8Tensors for FP8 MHA!" + assert isinstance(k, q.__class__) and isinstance( + v, q.__class__ + ), "q, k, and v must have the same type." + is_input_fp8 = isinstance(q, Float8Tensor) + if is_input_fp8: fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv q_fp8, k_fp8, v_fp8 = q, k, v q, k, v = q_fp8._data, k_fp8._data, v_fp8._data @@ -1778,7 +1781,7 @@ def forward( ) if not fp8: q_f16 = q - elif not fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): + elif not is_input_fp8 and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): q_f16 = q q = cast_to_fp8(q_f16, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward) @@ -1880,11 +1883,7 @@ def forward( batch_p2p_comm, ) - if ( - not fp8 - or fp8_meta["recipe"].fp8_mha - or int(os.getenv("NVTE_FP8_DPA_BWD", "1")) - ): + if not fp8 or is_input_fp8 or int(os.getenv("NVTE_FP8_DPA_BWD", "1")): kv_inputs[i % 2] = p2p_comm_buffers[i] else: # KV exchange is in BF16/FP16, cast received KV in each step @@ -2436,18 +2435,18 @@ def forward( fp8_meta["scaling_fwd"].amax_history[0][META_O_CP] = amax_cp_fwd[1] out_fp8 = None - out_f16 = out.to(q_fp8.dtype if fp8 and fp8_meta["recipe"].fp8_mha else q_f16.dtype) - if fp8 and (fp8_meta["recipe"].fp8_mha or int(os.getenv("NVTE_FP8_DPA_BWD", "1"))): + out_f16 = out.to(qkv_dtype) + if fp8 and (is_output_fp8 or int(os.getenv("NVTE_FP8_DPA_BWD", "1"))): out_fp8 = cast_to_fp8(out_f16, fp8_meta["scaling_fwd"], META_O, fp8_dtype_forward) - if fp8 and fp8_meta["recipe"].fp8_mha: + if fp8 and is_output_fp8: out_ret = Float8Tensor( data=out_fp8, fp8_meta=fp8_meta, fp8_meta_forward=True, fp8_meta_index=META_O, fp8_dtype=fp8_dtype_forward, - dtype=q_fp8.dtype, + dtype=qkv_dtype, ) else: out_ret = out_f16 @@ -2456,7 +2455,7 @@ def forward( q_save, kv_save, out_save = q, kv, out_fp8 fp8_fwd_scales = fp8_meta["scaling_fwd"].scale.clone() fp8_fwd_scale_invs = fp8_meta["scaling_fwd"].scale_inv.clone() - elif fp8 and fp8_meta["recipe"].fp8_mha: + elif fp8 and is_input_fp8: q_fp8 = Float8Tensor( data=q, fp8_meta=fp8_meta, @@ -2513,6 +2512,8 @@ def forward( ctx.use_fused_attention = use_fused_attention ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) ctx.fp8_meta = fp8_meta + ctx.is_input_fp8 = is_input_fp8 + ctx.is_output_fp8 = is_output_fp8 return out_ret @staticmethod @@ -2527,12 +2528,13 @@ def backward(ctx, dout): recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a] batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2) - (q, kv, out, softmax_lse, cu_seqlens_q_padded, 
cu_seqlens_kv_padded) = ctx.saved_tensors[:6] - (fp8_fwd_scales, fp8_fwd_scale_invs) = ctx.saved_tensors[6:8] - cu_seqlens_q_per_step = ctx.saved_tensors[8 : 8 + cp_size] - cu_seqlens_kv_per_step = ctx.saved_tensors[8 + cp_size : 8 + cp_size * 2] - rng_states = ctx.saved_tensors[8 + cp_size * 2 : 8 + cp_size * 3] - attn_biases = ctx.saved_tensors[8 + cp_size * 3 : 8 + cp_size * 4] + (*saved_tensors,) = ctx.saved_tensors + (q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded) = saved_tensors[:6] + (fp8_fwd_scales, fp8_fwd_scale_invs) = saved_tensors[6:8] + cu_seqlens_q_per_step = saved_tensors[8 : 8 + cp_size] + cu_seqlens_kv_per_step = saved_tensors[8 + cp_size : 8 + cp_size * 2] + rng_states = saved_tensors[8 + cp_size * 2 : 8 + cp_size * 3] + attn_biases = saved_tensors[8 + cp_size * 3 : 8 + cp_size * 4] causal = "causal" in ctx.attn_mask_type padding = "padding" in ctx.attn_mask_type @@ -2595,7 +2597,7 @@ def backward(ctx, dout): dq_fp8 = torch.empty((cp_size, *q.shape), dtype=q.dtype, device=q.device) dkv_fp8 = torch.empty((cp_size, *kv.shape), dtype=kv.dtype, device=kv.device) dkv_fp8_ = torch.empty_like(dkv_fp8) - if ctx.fp8_meta["recipe"].fp8_mha: + if ctx.is_output_fp8: assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!" ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = dout._scale_inv dout = dout._data @@ -2617,7 +2619,7 @@ def backward(ctx, dout): else: assert False, "FP8 is only supported with Fused Attention!" else: - if ctx.fp8_meta is not None and ctx.fp8_meta["recipe"].fp8_mha: + if ctx.fp8_meta is not None and ctx.is_input_fp8: q, kv = [x.from_float8(x.dtype) for x in [q, kv]] if cp_size_a2a == 1: dout = dout.from_float8(dout_dtype) @@ -2653,7 +2655,7 @@ def backward(ctx, dout): ctx.cp_stream, True, ) - if not ctx.fp8 and ctx.fp8_meta is not None and ctx.fp8_meta["recipe"].fp8_mha: + if not ctx.fp8 and ctx.fp8_meta is not None and ctx.is_output_fp8: dout = cast_from_fp8( dout, None, @@ -3260,7 +3262,7 @@ def backward(ctx, dout): dkv_[:, cu_seqlens_kv_padded[-1] :].fill_(0) dkv = dkv_ - if ctx.fp8 and ctx.fp8_meta["recipe"].fp8_mha: + if ctx.fp8 and ctx.is_input_fp8: dq, dkv = [ cast_to_fp8(x, ctx.fp8_meta["scaling_bwd"], META_DQKV, fp8_dtype_backward) for x in [dq, dkv] @@ -3283,7 +3285,7 @@ def backward(ctx, dout): elif ctx.qkv_format == "sbhd": dq, dk, dv = [x.view(-1, ctx.batch_size, *x.shape[-2:]) for x in [dq, dk, dv]] - if ctx.fp8 and ctx.fp8_meta["recipe"].fp8_mha: + if ctx.fp8 and ctx.is_input_fp8: dq, dk, dv = [ Float8Tensor( data=x, @@ -3576,11 +3578,12 @@ def backward(ctx, dout): cp_size = get_distributed_world_size(ctx.cp_group) rank = get_distributed_rank(ctx.cp_group) - (q, k, v, cu_seqlens_q, cu_seqlens_q_padded) = ctx.saved_tensors[:5] - cu_seqlens_kv_per_step = ctx.saved_tensors[5:7] - out_per_step = ctx.saved_tensors[7:9] - softmax_lse_per_step = ctx.saved_tensors[9:11] - rng_states = ctx.saved_tensors[11:13] + (*saved_tensors,) = ctx.saved_tensors + (q, k, v, cu_seqlens_q, cu_seqlens_q_padded) = saved_tensors[:5] + cu_seqlens_kv_per_step = saved_tensors[5:7] + out_per_step = saved_tensors[7:9] + softmax_lse_per_step = saved_tensors[9:11] + rng_states = saved_tensors[11:13] kv_seq_range_per_step = ctx.kv_seq_range_per_step window_size_per_step = ctx.window_size_per_step @@ -3852,19 +3855,22 @@ def forward( q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0 ), "Sequence length per GPU needs to be divisible by 2!" 
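The backward refactor above unpacks ctx.saved_tensors once and then slices the flat tuple into a fixed prefix plus four per-communication-step groups of cp_size entries each. The indexing is easier to see on a plain list (placeholder names are made up for illustration):

```python
# Sketch of the saved-tensor layout used in the context-parallel backward:
# 6 fixed entries, 2 FP8 scale entries, then 4 groups of cp_size entries each.
cp_size = 2
saved = (
    ["q", "kv", "out", "softmax_lse", "cu_seqlens_q_padded", "cu_seqlens_kv_padded"]
    + ["fp8_fwd_scales", "fp8_fwd_scale_invs"]
    + [f"cu_seqlens_q_step{i}" for i in range(cp_size)]
    + [f"cu_seqlens_kv_step{i}" for i in range(cp_size)]
    + [f"rng_state_step{i}" for i in range(cp_size)]
    + [f"attn_bias_step{i}" for i in range(cp_size)]
)

fixed = saved[:6]
fp8_scales = saved[6:8]
cu_seqlens_q_per_step = saved[8 : 8 + cp_size]
cu_seqlens_kv_per_step = saved[8 + cp_size : 8 + cp_size * 2]
rng_states = saved[8 + cp_size * 2 : 8 + cp_size * 3]
attn_biases = saved[8 + cp_size * 3 : 8 + cp_size * 4]
print(rng_states)  # ['rng_state_step0', 'rng_state_step1']
```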
+ qkv_dtype = q.dtype fused_attn_backend = None fused_attn_qkv_dtype = None + # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype + is_input_fp8 = False + is_output_fp8 = fp8_meta is not None and fp8_meta["recipe"].fp8_mha if fp8: if use_fused_attention: fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) fused_attn_qkv_dtype = fp8_dtype_forward fused_attn_backend = FusedAttnBackend["FP8"] - if fp8_meta["recipe"].fp8_mha: - assert ( - isinstance(q, Float8Tensor) - and isinstance(k, Float8Tensor) - and isinstance(v, Float8Tensor) - ), "q/k/v must be Float8Tensors for FP8 MHA!" + assert isinstance(k, q.__class__) and isinstance( + v, q.__class__ + ), "q, k, and v must have the same type." + is_input_fp8 = isinstance(q, Float8Tensor) + if is_input_fp8: fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv q_fp8, k_fp8, v_fp8 = q, k, v q, k, v = q_fp8._data, k_fp8._data, v_fp8._data @@ -3900,7 +3906,7 @@ def forward( [q, k, v], chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream, True ) - if fp8 and not fp8_meta["recipe"].fp8_mha and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): + if fp8 and not is_input_fp8 and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): q_f16, k_f16, v_f16 = q, k, v q, k, v = [ cast_to_fp8(x, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward) @@ -3965,14 +3971,14 @@ def forward( out = out.view(-1, batch_size, *out.shape[-2:]) if fp8: - if fp8_meta["recipe"].fp8_mha: + if is_output_fp8: out_fp8 = Float8Tensor( data=out, fp8_meta=fp8_meta, fp8_meta_forward=True, fp8_meta_index=META_O, fp8_dtype=fp8_dtype_forward, - dtype=q_fp8.dtype, + dtype=qkv_dtype, ) out = out_fp8._data out_ret = out_fp8 @@ -3991,7 +3997,7 @@ def forward( if fp8: if int(os.getenv("NVTE_FP8_DPA_BWD", "1")): q_save, k_save, v_save, out_save = q, k, v, out - elif fp8_meta["recipe"].fp8_mha: + elif is_input_fp8: q_fp8, k_fp8, v_fp8 = [ Float8Tensor( data=x, @@ -4043,6 +4049,8 @@ def forward( ctx.use_fused_attention = use_fused_attention ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) ctx.fp8_meta = fp8_meta + ctx.is_input_fp8 = is_input_fp8 + ctx.is_output_fp8 = is_output_fp8 return out_ret @staticmethod @@ -4050,12 +4058,11 @@ def backward(ctx, dout): # pylint: disable=missing-function-docstring cp_size = get_distributed_world_size(ctx.cp_group) - q, k, v, out = ctx.saved_tensors[:4] - cu_seqlens_q, cu_seqlens_kv, cu_seqlens_q_padded, cu_seqlens_kv_padded = ctx.saved_tensors[ - 4:8 - ] - fp8_fwd_scales, fp8_fwd_scale_invs = ctx.saved_tensors[8:10] - aux_ctx_tensors = ctx.saved_tensors[10:] + (*saved_tensors,) = ctx.saved_tensors + q, k, v, out = saved_tensors[:4] + cu_seqlens_q, cu_seqlens_kv, cu_seqlens_q_padded, cu_seqlens_kv_padded = saved_tensors[4:8] + fp8_fwd_scales, fp8_fwd_scale_invs = saved_tensors[8:10] + aux_ctx_tensors = saved_tensors[10:] qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format causal = "causal" in ctx.attn_mask_type @@ -4064,6 +4071,7 @@ def backward(ctx, dout): fused_attn_backend = None fused_attn_dqkv_dtype = None fused_attn_qkv_dtype = None + dout_dtype = dout.dtype if ctx.fp8: if ctx.use_fused_attention: fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True) @@ -4071,7 +4079,7 @@ def backward(ctx, dout): fused_attn_qkv_dtype = fp8_dtype_forward fused_attn_dqkv_dtype = fp8_dtype_backward fused_attn_backend = FusedAttnBackend["FP8"] - if ctx.fp8_meta["recipe"].fp8_mha: + if ctx.is_output_fp8: assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors 
for FP8 MHA!" ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = dout._scale_inv dout_fp8 = dout @@ -4097,7 +4105,7 @@ def backward(ctx, dout): else: assert False, "FP8 is only supported with Fused Attention!" else: - if ctx.fp8_meta is not None and ctx.fp8_meta["recipe"].fp8_mha: + if ctx.fp8_meta is not None and ctx.is_output_fp8: assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!" q, k, v, out, dout = [x.from_float8(x.dtype) for x in [q, k, v, out, dout]] if ctx.use_fused_attention: @@ -4194,7 +4202,7 @@ def backward(ctx, dout): dq, dk, dv = [x.view(-1, ctx.batch_size, *x.shape[-2:]) for x in [dq, dk, dv]] if ctx.fp8: - if ctx.fp8_meta["recipe"].fp8_mha: + if ctx.is_input_fp8: dq, dk, dv = [ Float8Tensor( data=x, @@ -4202,7 +4210,7 @@ def backward(ctx, dout): fp8_meta_forward=False, fp8_meta_index=META_DQKV, fp8_dtype=fp8_dtype_backward, - dtype=dout_fp8.dtype, + dtype=dout_dtype, ) for x in [dq, dk, dv] ] @@ -4213,7 +4221,7 @@ def backward(ctx, dout): ctx.fp8_meta["scaling_bwd"], META_DQKV, fp8_dtype_backward, - TE_DType[dout_f16.dtype], + TE_DType[dout_dtype], ) for x in [dq, dk, dv] ] @@ -5434,11 +5442,12 @@ def convert_to_torch_float8(tensor, dtype): ) return out - if fp8_meta["recipe"].fp8_mha: - assert all( - isinstance(x, Float8Tensor) - for x in [query_layer, key_layer, value_layer] - ), "q/k/v must be Float8Tensors for FP8 MHA." + # "fp8_mha" decides outputs in fp8, while inputs are inferred from + # the real dtype + assert isinstance(key_layer, query_layer.__class__) and isinstance( + value_layer, query_layer.__class__ + ), "q, k, and v must have the same type." + if isinstance(query_layer, Float8Tensor): fp8_meta["scaling_fwd"].scale_inv[META_QKV] = query_layer._scale_inv else: query_layer, key_layer, value_layer = ( @@ -5580,6 +5589,7 @@ def forward( deterministic, ): # pylint: disable=missing-function-docstring + # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype is_input_fp8 = False is_output_fp8 = fp8_meta["recipe"].fp8_mha if fp8: @@ -5970,6 +5980,7 @@ def forward( deterministic, ): # pylint: disable=missing-function-docstring + # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype is_input_fp8 = False is_output_fp8 = fp8_meta["recipe"].fp8_mha if fp8: @@ -6424,6 +6435,7 @@ def forward( deterministic, ): # pylint: disable=missing-function-docstring + # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype is_input_fp8 = False is_output_fp8 = fp8_meta["recipe"].fp8_mha if fp8: @@ -7673,6 +7685,60 @@ def forward( based on its internal logic. These optimizations trade memory for performance and should be used with care. + .. note:: + .. _cu_seqlens note: + + When training data has variable sequence lengths, users have two options. + + 1. Manipulate the data and pad all sequences to the same length. Use + :attr:`qkv_format` = {"bshd", "sbhd"} and + :attr:`attn_mask_type` = {"padding", "padding_causal", "padding_causal_bottom_right"}. + Pass in :attr:`cu_seqlens_q` and :attr:`cu_seqlens_kv`, or :attr:`attention_mask` + (which will be converted to :attr:`cu_seqlens_q` and :attr:`cu_seqlens_kv`), to provide + the real sequence length information. For example, a batch of 3 sequences + [a a a b b c c c c] can be padded to [a a a PAD b b PAD PAD c c c c], and the cumulative + sequence length tensors would be + :attr:`cu_seqlens_q` = :attr:`cu_seqlens_kv` = [0, 3, 5, 9] for self-attention. + + 2. Do not perform padding on training data. 
Use :attr:`qkv_format` = "thd" and + :attr:`attn_mask_type` = {"padding", "padding_causal", "padding_causal_bottom_right"}. + Pass in :attr:`cu_seqlens_q` and :attr:`cu_seqlens_kv`, or :attr:`attention_mask`, + as in option 1. For example, a batch of 3 sequences [a a a b b c c c c] can be processed + without any padding, and the sequence length tensors would be + :attr:`cu_seqlens_q` = :attr:`cu_seqlens_kv` = [0, 3, 5, 9] for self-attention. + + In certain use cases, a varying number of identifier tokens are inserted between + sequences. These tokens do not participate in the attention calculation. + :attr:`cu_seqlens_q_padded` and :attr:`cu_seqlens_kv_padded` must be specified + in such cases to correctly identify the start and end of each sequence in a batch. + For example, a batch of 3 sequences [a a a 1 b b 2 2 c c c c 3] would have + :attr:`cu_seqlens_q` = :attr:`cu_seqlens_kv` = [0, 3, 5, 9], and + :attr:`cu_seqlens_q_padded` = :attr:`cu_seqlens_kv_padded` = [0, 4, 8, 13] + for self-attention. + + .. note:: + .. _max_seqlen note: + + When :attr:`qkv_format` = {"bshd", "sbhd"}, sequences are of equal length in a batch. + :attr:`max_seqlen_q` and :attr:`max_seqlen_kv` should be the same as the "s" dimension of + :attr:`query_layer` and :attr:`key_layer` tensors. When unset, Transformer Engine will + infer them as such. + + When :attr:`qkv_format` = "thd", sequences have varying lengths. :attr:`max_seqlen_q` and + :attr:`max_seqlen_kv` should be the maximum query and key/value sequence length in a batch. + When unset, Transformer Engine deduces them from :attr:`cu_seqlens_q` and :attr:`cu_seqlens_kv`. + This deduction costs a small kernel and some CPU-GPU synchronization, and to avoid this + overhead, users are recommended to obtain the maximum sequence lengths from the data loaders + and pass them in. + + - As the maximum sequence lengths, batch size, and number of tokens change from batch to batch, + dynamic shapes need to be supported for tensor construction. FlashAttention and + UnfusedDotProductAttention naturally do so, while FusedAttention requires parameters to be static + to create graphs before performance heuristics analysis. To reduce the number of graphs created + per run, Transformer Engine 1.13+ quantizes relevant parameters: for cuDNN < 9.6, {batch size, + :attr:`max_seqlen_q`, :attr:`max_seqlen_kv`}, and for cuDNN >= 9.6, {"t" dimension of + :attr:`query_layer`, "t" dimension of :attr:`key_layer`}. + Parameters ---------- query_layer : torch.Tensor @@ -7695,25 +7761,29 @@ def forward( cu_seqlens_q: Optional[torch.Tensor], default = `None` Cumulative sum of sequence lengths (without offset) in a batch for `query_layer`, with shape [batch_size + 1] and dtype torch.int32. + See :ref:`note` for more details. cu_seqlens_kv: Optional[torch.Tensor], default = `None` Cumulative sum of sequence lengths (without offset) in a batch for `key_layer` and `value_layer`, with shape [batch_size + 1] and dtype torch.int32. + See :ref:`note` for more details. cu_seqlens_q_padded: Optional[torch.Tensor], default = `None` Cumulative sum of sequence lengths (with offset) in a batch for `query_layer`, with shape [batch_size + 1] and dtype torch.int32. When there is no padding between sequences in a batch, `cu_seqlens_q_padded = cu_seqlens_q`. + See :ref:`note` for more details. cu_seqlens_kv_padded: Optional[torch.Tensor], default = `None` Cumulative sum of sequence lengths (with offset) in a batch for `key_layer` and `value_layer`, with shape [batch_size + 1] and dtype torch.int32. 
When there is no padding between sequences in a batch, `cu_seqlens_kv_padded = cu_seqlens_kv`. + See :ref:`note` for more details. max_seqlen_q: Optional[int], default = `None` Maximum sequence length in `query_layer`. - Calculated from `cu_seqlens_q` if not provided. + See :ref:`note` for more details. max_seqlen_kv: Optional[int], default = `None` Maximum sequence length in `key_layer` and `value_layer`. - Calculated from `cu_seqlens_kv` if not provided. + See :ref:`note` for more details. attn_mask_type: {'no_mask', 'padding', 'causal', 'padding,causal', 'causal,padding', 'padding_causal', 'causal_bottom_right', 'padding_causal_bottom_right', 'arbitrary'}, default = `None`. Type of attention mask passed into @@ -7882,7 +7952,10 @@ def forward( assert ( key_layer.shape[-2] == self.num_gqa_groups_per_partition and value_layer.shape[-2] == self.num_gqa_groups_per_partition - ), f"Keys and values must have num_gqa_group = {self.num_gqa_groups} heads!" + ), ( + "Keys and values must have num_gqa_group =" + f" {self.num_gqa_groups_per_partition} heads!" + ) assert qkv_format in [ "sbhd", "bshd", @@ -7904,6 +7977,7 @@ def forward( assert ( cu_seqlens_q.dtype == torch.int32 and cu_seqlens_kv.dtype == torch.int32 ), "cu_seqlens_q and cu_seqlens_q must both be in dtype torch.int32!" + batch_size = len(cu_seqlens_q) - 1 if max_seqlen_q is None: if cu_seqlens_q_padded is not None: seqlens_q = cu_seqlens_q_padded[1:] - cu_seqlens_q_padded[:-1] @@ -7916,7 +7990,6 @@ def forward( else: seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] max_seqlen_kv = int((seqlens_kv.max().item() + 63) // 64 * 64) - batch_size = len(cu_seqlens_q) - 1 cp_size = 1 if isinstance(self.cp_group, dist_group_type): @@ -7931,10 +8004,12 @@ def forward( len(x.shape) == 4 for x in (query_layer, key_layer, value_layer) ), f"Queries, keys and values must be 4D tensors when qkv_format = {qkv_format}!" 
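The cu_seqlens tensors described in the notes above are cumulative sums of the per-sequence lengths, and when max_seqlen_q/max_seqlen_kv are left unset for the "thd" format, the forward pass derives them from cu_seqlens and rounds the maximum up to a multiple of 64. A small sketch of both derivations, matching the [0, 3, 5, 9] example from the docstring (NumPy is used for brevity; the real code operates on torch tensors):

```python
import numpy as np

seq_lens = np.array([3, 2, 4], dtype=np.int32)  # batch of 3 variable-length sequences
cu_seqlens = np.concatenate([[0], np.cumsum(seq_lens)]).astype(np.int32)
print(cu_seqlens)  # [0 3 5 9], as in the docstring example

# Deriving max_seqlen when it is not passed in: per-sequence lengths come back as
# differences of adjacent cu_seqlens entries; the maximum is rounded up to a multiple of 64.
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
max_seqlen = int((seqlens.max() + 63) // 64 * 64)
print(max_seqlen)  # 64
```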
if qkv_format == "sbhd": - max_seqlen_q, max_seqlen_kv = (query_layer.shape[0], key_layer.shape[0]) + max_seqlen_q = query_layer.shape[0] if max_seqlen_q is None else max_seqlen_q + max_seqlen_kv = key_layer.shape[0] if max_seqlen_kv is None else max_seqlen_kv batch_size = query_layer.shape[1] else: - max_seqlen_q, max_seqlen_kv = (query_layer.shape[1], key_layer.shape[1]) + max_seqlen_q = query_layer.shape[1] if max_seqlen_q is None else max_seqlen_q + max_seqlen_kv = key_layer.shape[1] if max_seqlen_kv is None else max_seqlen_kv batch_size = query_layer.shape[0] max_seqlen_q *= cp_size max_seqlen_kv *= cp_size @@ -7943,13 +8018,13 @@ def forward( assert all( seqlens_q <= max_seqlen_q ), """Sequence lengths indicated by cu_seqlens_q must be no greater than - the sequence dimention in 'query_layer'!""" + the sequence dimension in 'query_layer'!""" if cu_seqlens_kv is not None: seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] assert all( seqlens_kv <= max_seqlen_kv ), """Sequence lengths indicated by cu_seqlens_kv must be no greater than - the sequence dimention in 'key_layer' and 'value_layer'!""" + the sequence dimension in 'key_layer' and 'value_layer'!""" if cu_seqlens_q is None or cu_seqlens_kv is None: if "padding" in attn_mask_type: assert ( diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index fd1eb4a810..932bb3cafa 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -45,8 +45,8 @@ def fp8_gemm( use_bias: bool = False, use_split_accumulator: bool = False, D_dtype: Optional[tex.DType] = None, - ub_algo: tex.UbufOverlapAlgo = None, - ub: Union[tex.UbufCommOverlap, tex.UbufP2PCommOverlap] = None, + ub_algo: tex.CommOverlapAlgo = None, + ub: Union[tex.CommOverlap, tex.CommOverlapP2P] = None, extra_output_tensor: torch.Tensor = None, ) -> torch.Tensor: """TN layout GEMM with fp8 inputs.""" @@ -107,7 +107,7 @@ def fp8_gemm( fn = torch.ops.tex_ts.te_gemm_ts if ub_algo is not None: assert ub is not None, "ub object is None!" 
- if ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_AG: + if ub_algo == tex.CommOverlapAlgo.BULK_OVERLAP_AG: fn = ub.bulk_overlap extra_output_tensor = ( empty_tensor if extra_output_tensor is None else extra_output_tensor @@ -115,11 +115,11 @@ def fp8_gemm( args = tuple( args + ( - 1, + tex.CommOverlapType.AG, extra_output_tensor, ) ) - elif ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_RS: + elif ub_algo == tex.CommOverlapAlgo.BULK_OVERLAP_RS: fn = ub.bulk_overlap extra_output_tensor = ( empty_tensor if extra_output_tensor is None else extra_output_tensor @@ -127,23 +127,23 @@ def fp8_gemm( args = tuple( args + ( - 0, + tex.CommOverlapType.RS, extra_output_tensor, ) ) - elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P: + elif ub_algo == tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P: fn = ub.split_overlap_ag_p2p extra_output_tensor = ( empty_tensor if extra_output_tensor is None else extra_output_tensor ) args = tuple(args + (extra_output_tensor,)) - elif ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P: + elif ub_algo == tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P: fn = ub.atomic_gemm_overlap_ag_p2p extra_output_tensor = ( empty_tensor if extra_output_tensor is None else extra_output_tensor ) args = tuple(args + (extra_output_tensor,)) - elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS: + elif ub_algo == tex.CommOverlapAlgo.SPLIT_PIPELINED_RS: fn = ub.split_overlap_rs assert ( extra_output_tensor is not None @@ -155,13 +155,13 @@ def fp8_gemm( extra_output_tensor, ) ) - elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P: + elif ub_algo == tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P: fn = ub.split_overlap_rs_p2p assert ( extra_output_tensor is not None ), "SPLIT_PIPELINED_RS_P2P requires extra output tensor" args = tuple(args + (extra_output_tensor,)) - elif ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_RS: + elif ub_algo == tex.CommOverlapAlgo.ATOMIC_GEMM_RS: fn = ub.atomic_gemm_overlap_rs assert extra_output_tensor is not None, "ATOMIC_GEMM_RS requires extra output tensor" args = tuple( @@ -171,16 +171,13 @@ def fp8_gemm( extra_output_tensor, ) ) - elif ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P: + elif ub_algo == tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P: fn = ub.atomic_gemm_overlap_rs_p2p assert ( extra_output_tensor is not None ), "ATOMIC_GEMM_RS_P2P requires extra output tensor" args = tuple(args + (extra_output_tensor,)) - if ub_algo is not None and ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P: - out = fn(*args) - else: - _ = fn(*args) + _ = fn(*args) return out, gelu_input @@ -198,8 +195,8 @@ def gemm( out: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None, use_bias: bool = False, - ub_algo: tex.UbufOverlapAlgo = None, - ub: tex.UbufCommOverlap = None, + ub_algo: tex.CommOverlapAlgo = None, + ub: Union[tex.CommOverlap, tex.CommOverlapP2P] = None, extra_output_tensor: torch.Tensor = None, ) -> Tuple[Union[torch.Tensor, None], ...]: """Non FP8 GEMM.""" @@ -270,19 +267,19 @@ def gemm( fn = torch.ops.tex_ts.te_gemm_ts if ub_algo is not None: assert ub is not None, "ub object is None!" 
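The fp8_gemm branches above pick a userbuffers method and its extra arguments based on the CommOverlapAlgo value; the bulk variants now pass a CommOverlapType (AG or RS) instead of the old integer flag. The selection amounts to a dispatch table, sketched here with local stand-in enums rather than the tex extension module, and covering only the variants visible in this hunk:

```python
from enum import Enum, auto

class CommOverlapAlgo(Enum):   # local stand-in for tex.CommOverlapAlgo
    BULK_OVERLAP_AG = auto()
    BULK_OVERLAP_RS = auto()
    SPLIT_PIPELINED_AG_P2P = auto()

class CommOverlapType(Enum):   # local stand-in for tex.CommOverlapType
    AG = auto()
    RS = auto()

# Algorithm -> (ub method name, extra-argument builder); mirrors the if/elif chain.
DISPATCH = {
    CommOverlapAlgo.BULK_OVERLAP_AG: ("bulk_overlap", lambda extra: (CommOverlapType.AG, extra)),
    CommOverlapAlgo.BULK_OVERLAP_RS: ("bulk_overlap", lambda extra: (CommOverlapType.RS, extra)),
    CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P: ("split_overlap_ag_p2p", lambda extra: (extra,)),
}

method, build_args = DISPATCH[CommOverlapAlgo.BULK_OVERLAP_RS]
print(method, build_args("extra_output_tensor"))
```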
- if ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_AG: + if ub_algo == tex.CommOverlapAlgo.BULK_OVERLAP_AG: fn = ub.bulk_overlap - args = tuple(args + (1, empty_tensor)) - elif ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_RS: + args = tuple(args + (tex.CommOverlapType.AG, empty_tensor)) + elif ub_algo == tex.CommOverlapAlgo.BULK_OVERLAP_RS: fn = ub.bulk_overlap - args = tuple(args + (0, empty_tensor)) - elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P: + args = tuple(args + (tex.CommOverlapType.RS, empty_tensor)) + elif ub_algo == tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P: fn = ub.split_overlap_ag_p2p extra_output_tensor = ( empty_tensor if extra_output_tensor is None else extra_output_tensor ) args = tuple(args + (extra_output_tensor,)) - elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS: + elif ub_algo == tex.CommOverlapAlgo.SPLIT_PIPELINED_RS: fn = ub.split_overlap_rs assert ( extra_output_tensor is not None @@ -294,7 +291,7 @@ def gemm( extra_output_tensor, ) ) - elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P: + elif ub_algo == tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P: fn = ub.split_overlap_rs_p2p assert ( extra_output_tensor is not None diff --git a/transformer_engine/pytorch/cpp_extensions/transpose.py b/transformer_engine/pytorch/cpp_extensions/transpose.py index ddc3b67e9e..188c03b27c 100644 --- a/transformer_engine/pytorch/cpp_extensions/transpose.py +++ b/transformer_engine/pytorch/cpp_extensions/transpose.py @@ -16,6 +16,7 @@ "fp8_cast_transpose_fused", "fp8_cast_transpose_bgrad_fused", "fp8_cast_transpose_bgrad_dgelu_fused", + "fp8_dswiglu_cast_transpose_fused", "fp8_multi_cast_transpose_fused", "fp8_transpose_bgrad_fused", ] @@ -168,6 +169,44 @@ def fp8_cast_transpose_bgrad_dgelu_fused( ) +def fp8_dswiglu_cast_transpose_fused( + grad_output: torch.Tensor, + inp: torch.Tensor, + *, + grad_input: torch.Tensor, + grad_input_transpose: torch.Tensor, + otype: tex.DType, + fp8_meta: Optional[tex.FP8TensorMeta] = None, + fp8_meta_index: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None] = None, + scale: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, +) -> None: + """Fused SwiGLU backward + FP8 cast + FP8 transpose""" + + # Get FP8 scaling factors + fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( + scale=scale, + amax=amax, + scale_inv=scale_inv, + fp8_meta=fp8_meta, + fp8_meta_index=fp8_meta_index, + ) + + # Launch kernel + return tex.fused_dswiglu_cast_transpose( + grad_output, + inp, + grad_input, + grad_input_transpose, + fp8_scales["scale"], + fp8_scales["amax"], + fp8_scales["scale_inv"], + otype, + **fp8_scales_offsets, + ) + + def fp8_multi_cast_transpose_fused( input_list: List[torch.Tensor], fp8_meta_tensor: tex.FP8TensorMeta, diff --git a/transformer_engine/pytorch/csrc/comm_gemm_overlap.h b/transformer_engine/pytorch/csrc/comm_gemm_overlap.h deleted file mode 100644 index 3b4e126943..0000000000 --- a/transformer_engine/pytorch/csrc/comm_gemm_overlap.h +++ /dev/null @@ -1,1303 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. 
- ************************************************************************/ - -#ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_COMM_GEMM_OVERLAP_H_ -#define TRANSFORMER_ENGINE_PYTORCH_CSRC_COMM_GEMM_OVERLAP_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -#include "common/common.h" -#include "common/util/cuda_driver.h" -#include "common/util/logging.h" -#include "common/util/system.h" -#include "extensions.h" -#include "userbuffers/userbuffers.h" - -#define HALF_BYTES 2 -#define UB_MAX_SM 32 - -using namespace torch::indexing; -using namespace std::placeholders; - -namespace ubuf { - -bool device_supports_multicast() { - int dev, supports_multicast; - CUdevice cudev; - - NVTE_CHECK_CUDA(cudaGetDevice(&dev)); - NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGet, &cudev, dev); - NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGetAttribute, &supports_multicast, - CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, cudev); - - return static_cast(supports_multicast); -} - -bool ubuf_built_with_mpi() { -#ifdef NVTE_UB_WITH_MPI - return true; -#else - return false; -#endif -} - -class UbufBootstrapCallbacks : torch::CustomClassHolder { - private: - bool initialized{false}; - bool backend_is_nccl{false}; - std::map pgs; - - public: - UbufBootstrapCallbacks() { -#ifndef NVTE_UB_WITH_MPI - NVTE_ERROR("Internal TE error: Dummy UbufBootstrapCallbacks init without NVTE_UB_WITH_MPI=1!"); -#endif - } // empty constructor for NVTE_UB_WITH_MPI=1 - - UbufBootstrapCallbacks(c10d::ProcessGroup *world_group, c10d::ProcessGroup *intra_node_group) { - pgs.insert({"world", world_group}); - c10d::ProcessGroup::BackendType backend = world_group->getBackendType(); - backend_is_nccl = (backend == c10d::ProcessGroup::BackendType::NCCL); - - NVTE_CHECK(intra_node_group->getBackendType() == backend, - "Internal TE error: Intra-node group must be on the same backend (%s) as the world ", - "group!", world_group->getBackendName()); - pgs.insert({"intra", intra_node_group}); - - initialized = true; - } - - ~UbufBootstrapCallbacks() { - for (auto &pg : pgs) pg.second = nullptr; - backend_is_nccl = false; - initialized = false; - } - - void ub_allgather(void *globaldata, size_t globalbytes, void *localdata, size_t localbytes, - char *group) { - NVTE_CHECK(initialized, "Internal TE error: tex.UbufBootstrapCallbacks() is not initialized ", - "with valid process groups!"); - - auto localtensor = - torch::from_blob(localdata, {static_cast(localbytes / sizeof(uint8_t))}, - at::device(torch::kCPU).dtype(torch::kUInt8)); - auto localtmp = (backend_is_nccl) ? localtensor.cuda() : localtensor; - auto globaltensor = - torch::from_blob(globaldata, {static_cast(globalbytes / sizeof(uint8_t))}, - at::device(torch::kCPU).dtype(torch::kUInt8)); - auto globaltmp = (backend_is_nccl) ? 
globaltensor.cuda() : globaltensor; - - std::vector> globalchunks = {globaltmp.chunk(pgs[group]->getSize())}; - std::vector localchunk = {localtmp}; - auto work = pgs[group]->allgather(globalchunks, localchunk); - work->wait(); - - if (backend_is_nccl) { - globaltensor.copy_(globaltmp.cpu()); - globaltmp = torch::Tensor(); - localtmp = torch::Tensor(); - } - } - - void ub_barrier(char *group) { - NVTE_CHECK(initialized, "Internal TE error: tex.UbufBootstrapCallbacks() is not initialized ", - "with valid process groups!"); - auto work = pgs[group]->barrier(); - work->wait(); - } -}; - -enum class COMM_TYPE { RS = 0, AG = 1 }; - -enum class UBOverlapAlgo { - BULK_OVERLAP_AG = 0, - BULK_OVERLAP_RS = 1, - SPLIT_PIPELINED_AG_P2P = 2, - SPLIT_PIPELINED_RS = 3, - SPLIT_PIPELINED_RS_P2P = 4, - ATOMIC_GEMM_RS = 5, - ATOMIC_GEMM_AG_P2P = 6, - ATOMIC_GEMM_RS_P2P = 7 -}; - -struct UbufBase { - static inline communicator *_ub_comm{nullptr}; - static inline bool comm_created{false}; -}; -struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { - int _tp_id; - int _tp_size; - int _num_splits; - int _math_sms; - int _ub_reg; - void *_ubuf_ptr; - torch::Tensor _ubuf; - torch::Tensor output_tensor; - torch::Tensor _ubuf_scale_inv; - bool _ubuf_scale_inv_initialized; - torch::Tensor counter; - at::cuda::CUDAStream _stream_comm = at::cuda::getStreamFromPool(true); - std::vector _stream_compute; - cudaEvent_t _start_compute, _stop_compute, _start_d2dcopy, _start_comm, _stop_comm; - int _num_comm_sm; - int _cga_size; - int _use_ce; - bool _atomic_gemm; - - UbufCommOverlap(torch::Tensor sample, int myrank, int numranks, int mylocal, int numlocal, - int mynode, int numnodes, int tp_size, int num_comm_sm, int comm_cga_size, - int num_splits, bool set_sm_margin, int num_max_streams, bool atomic_gemm, - UbufBootstrapCallbacks &callbacks) { - // Initialize userbuf communicator - if (!comm_created) { - if (myrank == 0) { - printf("!!! [UB] Create Userbuffers Communicator\n"); - } -#ifdef NVTE_UB_WITH_MPI - create_communicator_grouped2_mpi(&_ub_comm, 1, 1, tp_size, 1); -#else - create_communicator_grouped2( - &_ub_comm, myrank, numranks, mylocal, numlocal, mynode, numnodes, - std::bind(&UbufBootstrapCallbacks::ub_allgather, callbacks, _1, _2, _3, _4, _5), - std::bind(&UbufBootstrapCallbacks::ub_barrier, callbacks, _1), 1, 1, tp_size, 1); -#endif - comm_created = true; - } - _use_ce = 0; - _num_comm_sm = num_comm_sm; - _cga_size = comm_cga_size; - - // Allocate and register extra userbuffers - int ubuf_bytes = sample.numel() * sample.element_size(); - _ub_reg = register_user_buffer_collective(reinterpret_cast(&_ubuf_ptr), ubuf_bytes, - _ub_comm, true); - _ubuf = torch::from_blob(_ubuf_ptr, {sample.size(0), sample.size(1)}, sample.options()); - - if (_ub_comm->myrank == 0) { - printf("!!! [UB] Register UBuf %d\n", _ub_reg); - } - - at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); - for (int i = 0; i < std::min(num_max_streams, num_splits); i++) { - cudaStream_t stream; - cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, -1); - _stream_compute.push_back( - at::cuda::getStreamFromExternal(stream, stream_main.device_index())); - } - - _num_splits = num_splits; - _tp_size = tp_size; - _tp_id = (_ub_comm->myrank % _tp_size); - _ubuf_scale_inv_initialized = false; - - // Set the number of SMs for GEMM with margin - cudaDeviceProp prop; - cudaGetDeviceProperties(&prop, 0); - _math_sms = (set_sm_margin) ? 
prop.multiProcessorCount - num_comm_sm : prop.multiProcessorCount; - _math_sms -= transformer_engine::getenv("NVTE_EXT_MARGIN_SM", 0); - - output_tensor = torch::Tensor(); - _atomic_gemm = atomic_gemm; - if (_atomic_gemm) { - auto counter_options = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA); - counter = torch::zeros({num_splits * 2}, counter_options); - counter.index_put_({Slice(None, num_splits)}, 1); - } - // CUDA event creation - cudaEventCreateWithFlags(&_start_compute, 0); - cudaEventCreateWithFlags(&_stop_compute, 0); - cudaEventCreateWithFlags(&_start_d2dcopy, 0); - cudaEventCreateWithFlags(&_start_comm, 0); - cudaEventCreateWithFlags(&_stop_comm, 0); - } - - ~UbufCommOverlap() { - cudaEventDestroy(_stop_comm); - cudaEventDestroy(_start_comm); - cudaEventDestroy(_start_d2dcopy); - cudaEventDestroy(_stop_compute); - cudaEventDestroy(_start_compute); - - for (size_t i = 0; i < _stream_compute.size(); i++) cudaStreamDestroy(_stream_compute[i]); - - if (comm_created) { -#ifdef NVTE_UB_WITH_MPI - destroy_communicator_mpi(_ub_comm); -#else - destroy_communicator(_ub_comm); -#endif - comm_created = false; - } - } - - /* - ** Bulk GEMM + COMM - ** This function assumes the communication input is pre-copied to _ubuf - */ - std::vector bulk_overlap( - at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, - transformer_engine::DType A_type, bool transa, at::Tensor B, at::Tensor B_scale_inverse, - int64_t B_fp8_tensor, transformer_engine::DType B_type, bool transb, at::Tensor D, - at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias, - transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, - size_t workspaceSize, bool accumulate, bool use_split_accumulator, int comm_type, - at::Tensor rs_output) { - int ori_sms = _ub_comm->sms; - _ub_comm->use_ce = _use_ce; - _ub_comm->sms = _num_comm_sm; - _ub_comm->cga_size = _cga_size; - // Get the current userbuf offset - char *ubuf_wt_ptr = reinterpret_cast(_ubuf.data_ptr()); - int comm_elements = (_ubuf.numel() / 2) * _ubuf.element_size(); // UBUF uses 2Byte element size - COMM_TYPE _comm_type = static_cast(comm_type); - if (_comm_type == COMM_TYPE::RS) { - ubuf_wt_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size(); - } - - // Catch up the default torch stream - at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); - NVTE_CHECK_CUDA(cudaEventRecord(_start_comm, (cudaStream_t)stream_main)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0)); - - // Communication: AG and RS - if (_comm_type == COMM_TYPE::AG) { - allgather2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, (cudaStream_t)_stream_comm); - } else if (_comm_type == COMM_TYPE::RS) { - if (_ubuf.element_size() == 1) { - assert(_ubuf_scale_inv_initialized); - comm_elements *= 2; - float *scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); - assert(rs_output.numel() == _ubuf.numel() / _tp_size); - assert(rs_output.size(0) == _ubuf.size(0) / _tp_size); - assert(rs_output.element_size() == 2); - char *rs_output_ptr = reinterpret_cast(rs_output.data_ptr()); - reducescatter2_userbuff_fp8<__nv_fp8_e5m2>(rs_output_ptr, scale_inv_ptr, _ub_reg, 0, - comm_elements, _ub_comm, - (cudaStream_t)_stream_comm); - } else { - reducescatter2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, - (cudaStream_t)_stream_comm); - } - } else { - NVTE_ERROR("Not supported communication type."); - } - - if (A_scale_inverse.numel()) 
A_scale_inverse = A_scale_inverse[A_fp8_tensor]; - - if (B_scale_inverse.numel()) B_scale_inverse = B_scale_inverse[B_fp8_tensor]; - - assert(pre_gelu_out.numel() == 0); - te_gemm(A, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb, D, D_scale, - D_type, D_amax, bias, bias_type, pre_gelu_out, grad, workspace, workspaceSize, - accumulate, use_split_accumulator, _math_sms); - - NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, (cudaStream_t)_stream_comm)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_comm, 0)); - - // Generate output tensor from userbuf data pointer - int output_c_dim0 = (_comm_type == COMM_TYPE::AG) ? _ubuf.size(0) : _ubuf.size(0) / _tp_size; - int output_c_dim1 = _ubuf.size(1); - output_tensor = torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, _ubuf.options()); - _ub_comm->sms = ori_sms; - - return {D, output_tensor}; - } // bulk_overlap - - /* - ** Split FPROP GEMM + ReduceScatter - */ - void atomic_gemm_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, - transformer_engine::DType A_type, bool transa, at::Tensor B, - at::Tensor B_scale_inverse, int64_t B_fp8_tensor, - transformer_engine::DType B_type, bool transb, at::Tensor D, - at::Tensor D_scale, transformer_engine::DType D_type, - at::Tensor D_amax, at::Tensor bias, - transformer_engine::DType bias_type, at::Tensor pre_gelu_out, - bool grad, at::Tensor workspace, size_t workspaceSize, - bool accumulate, bool use_split_accumulator, bool gemm_overlap, - at::Tensor rs_output) { - int ori_sms = _ub_comm->sms; - _ub_comm->use_ce = _use_ce; - _ub_comm->sms = _num_comm_sm; - _ub_comm->cga_size = _cga_size; - // Get GEMM dimensions - int m = A.size(0); - int k = A.size(1); - int n = B.size(0); - int m_chunk = m / _num_splits; - int workspace_size_chunk = workspaceSize / _stream_compute.size(); - - // Get input, output, and workspace data pointers - char *input_a_chunk_ptr = reinterpret_cast(A.data_ptr()); - char *output_buf_chunk_ptr = reinterpret_cast(_ubuf.data_ptr()); - char *workspace_ptr = reinterpret_cast(workspace.data_ptr()); - int *counter_ptr = reinterpret_cast(counter.data_ptr()); - char *rs_output_ptr = reinterpret_cast(rs_output.data_ptr()); - - // Catch up the default torch stream - at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); - NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[0], _start_compute, 0)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_compute, 0)); - - if (A_scale_inverse.numel()) A_scale_inverse = A_scale_inverse[A_fp8_tensor]; - - if (B_scale_inverse.numel()) B_scale_inverse = B_scale_inverse[B_fp8_tensor]; - - assert(pre_gelu_out.numel() == 0); - - torch::Tensor input_a = torch::from_blob(input_a_chunk_ptr, {m, k}, A.options()); - torch::Tensor output_d = torch::from_blob(output_buf_chunk_ptr, {n, m}, _ubuf.options()); - // torch::zeros({n, m}, _ubuf.options()); - torch::Tensor workspace_chunk = - torch::from_blob(workspace_ptr, {workspace_size_chunk}, workspace.options()); - at::cuda::setCurrentCUDAStream(_stream_compute[0]); - te_atomic_gemm(input_a, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb, - output_d, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, - workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, - _math_sms, _num_splits /*m_split*/, 0 /*n_split*/, true /*gemm_producer*/, - counter); - - for (int i = 0; i < _num_splits; 
i++) { - const char *env_p = std::getenv("NVTE_RS_STRIDED_ATOMIC"); - if (env_p != nullptr && env_p[0] == '1') { - if (i == _num_splits - 1) { - _ub_comm->sms = UB_MAX_SM; - } - if (_ubuf.element_size() == 1) { - assert(_ubuf_scale_inv_initialized); - float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - D_type, fp8_type, - reducescatter2_userbuff_strided_atomic_fp8( - rs_output_ptr, d_scale_inv_ptr, _ub_reg, i * m_chunk, m_chunk, n, m, m, - _num_splits, &counter_ptr[i], _ub_comm, (cudaStream_t)_stream_comm);); - } else { - reducescatter2_userbuff_strided_atomic(rs_output_ptr, _ub_reg, i * m_chunk, m_chunk, n, m, - _num_splits, &counter_ptr[i], _ub_comm, - (cudaStream_t)_stream_comm); - } - } else if (env_p != nullptr && env_p[0] == '2') { - if (_ubuf.element_size() == 1) { - assert(_ubuf_scale_inv_initialized); - float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - D_type, fp8_type, - reducescatter2_userbuff_strided_multiatomic_fp8( - rs_output_ptr, d_scale_inv_ptr, _ub_reg, m_chunk, m_chunk, n, m, m, _num_splits, - counter_ptr, _ub_comm, (cudaStream_t)_stream_comm);); - } else { - reducescatter2_userbuff_strided_multiatomic(rs_output_ptr, _ub_reg, m_chunk, m_chunk, n, - m, _num_splits, counter_ptr, _ub_comm, - (cudaStream_t)_stream_comm); - } - break; - } else { - assert(_ubuf.element_size() != 1); - consumer(counter_ptr, i, (cudaStream_t)_stream_comm); - reducescatter2_userbuff_strided(rs_output_ptr, _ub_reg, i * m_chunk, m_chunk, n, m, - _ub_comm, (cudaStream_t)_stream_comm); - } - - rs_output_ptr += m_chunk * rs_output.element_size(); - } - - _ub_comm->sms = ori_sms; - NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[0])); - NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, (cudaStream_t)_stream_comm)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_comm, 0)); - at::cuda::setCurrentCUDAStream(stream_main); - - return; - } // split_overlap_rs - - /* - ** Split FPROP GEMM + ReduceScatter - */ - void split_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, - transformer_engine::DType A_type, bool transa, at::Tensor B, - at::Tensor B_scale_inverse, int64_t B_fp8_tensor, - transformer_engine::DType B_type, bool transb, at::Tensor D, - at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, - at::Tensor bias, transformer_engine::DType bias_type, - at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, - size_t workspaceSize, bool accumulate, bool use_split_accumulator, - bool gemm_overlap, at::Tensor rs_output) { - // Get GEMM dimensions - int ori_sms = _ub_comm->sms; - _ub_comm->use_ce = _use_ce; - _ub_comm->sms = _num_comm_sm; - _ub_comm->cga_size = _cga_size; - int m = A.size(0); - int k = A.size(1); - int n = B.size(0); - int m_chunk = m / _num_splits; - int input_a_chunk_size = m_chunk * k; - int output_chunk_size = n * m_chunk; - int workspace_size_chunk = workspaceSize / _stream_compute.size(); - - // Get input, output, and workspace data pointers - char *input_a_chunk_ptr = reinterpret_cast(A.data_ptr()); - char *output_buf_chunk_ptr = reinterpret_cast(_ubuf.data_ptr()); - char *workspace_ptr = reinterpret_cast(workspace.data_ptr()); - - char *rs_output_ptr = reinterpret_cast(rs_output.data_ptr()); - - // Catch up the default torch stream - at::cuda::CUDAStream stream_main = 
at::cuda::getCurrentCUDAStream(); - NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); - for (size_t i = 0; i < _stream_compute.size(); i++) { - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[i], _start_compute, 0)); - } - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_compute, 0)); - - if (A_scale_inverse.numel()) A_scale_inverse = A_scale_inverse[A_fp8_tensor]; - - if (B_scale_inverse.numel()) B_scale_inverse = B_scale_inverse[B_fp8_tensor]; - - assert(pre_gelu_out.numel() == 0); - - if (gemm_overlap) { - torch::Tensor input_a_chunk = torch::from_blob(input_a_chunk_ptr, {m_chunk, k}, A.options()); - torch::Tensor output_chunk = - torch::from_blob(output_buf_chunk_ptr, {n, m_chunk}, _ubuf.options()); - torch::Tensor workspace_chunk = - torch::from_blob(workspace_ptr, {workspace_size_chunk}, workspace.options()); - at::cuda::setCurrentCUDAStream(_stream_compute[0]); - te_gemm(input_a_chunk, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb, - output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, - workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, _math_sms); - - for (int i = 1; i < _num_splits; i++) { - input_a_chunk_ptr += input_a_chunk_size * B.element_size(); - output_buf_chunk_ptr += output_chunk_size * _ubuf.element_size(); - - torch::Tensor input_a_chunk = - torch::from_blob(input_a_chunk_ptr, {m_chunk, k}, A.options()); - torch::Tensor output_chunk = - torch::from_blob(output_buf_chunk_ptr, {n, m_chunk}, _ubuf.options()); - torch::Tensor workspace_chunk = - torch::from_blob(workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk, - {workspace_size_chunk}, workspace.options()); - at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]); - te_gemm(input_a_chunk, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb, - output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, - workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, - _math_sms); - - NVTE_CHECK_CUDA(cudaEventRecord( - _start_comm, (cudaStream_t)_stream_compute[(i - 1) % _stream_compute.size()])); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0)); - - // Communication chunk - if (_ubuf.element_size() == 1) { - assert(_ubuf_scale_inv_initialized); - float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - D_type, fp8_type, - reducescatter2_userbuff_stridedoutput_fp8( - rs_output_ptr, d_scale_inv_ptr, _ub_reg, (i - 1) * output_chunk_size, m_chunk, n, - m, _ub_comm, (cudaStream_t)_stream_comm);); - } else { - reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, (i - 1) * output_chunk_size, - m_chunk, n, m, _ub_comm, - (cudaStream_t)_stream_comm); - } - - rs_output_ptr += m_chunk * rs_output.element_size(); - } - int last_compute_stream_id = - (_num_splits + _stream_compute.size() - 1) % _stream_compute.size(); - NVTE_CHECK_CUDA( - cudaEventRecord(_start_comm, (cudaStream_t)_stream_compute[last_compute_stream_id])); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0)); - - // Last communication chunk with max SM - _ub_comm->sms = UB_MAX_SM; - if (_ubuf.element_size() == 1) { - assert(_ubuf_scale_inv_initialized); - float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - D_type, fp8_type, - reducescatter2_userbuff_stridedoutput_fp8( - 
rs_output_ptr, d_scale_inv_ptr, _ub_reg, (_num_splits - 1) * output_chunk_size, - m_chunk, n, m, _ub_comm, (cudaStream_t)_stream_comm);); - } else { - reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, - (_num_splits - 1) * output_chunk_size, m_chunk, n, m, - _ub_comm, (cudaStream_t)_stream_comm); - } - } else { - for (int i = 0; i < _num_splits; i++) { - torch::Tensor input_a_chunk = - torch::from_blob(input_a_chunk_ptr, {m_chunk, k}, A.options()); - torch::Tensor output_chunk = - torch::from_blob(output_buf_chunk_ptr, {n, m_chunk}, _ubuf.options()); - torch::Tensor workspace_chunk = - torch::from_blob(workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk, - {workspace_size_chunk}, workspace.options()); - at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]); - te_gemm(input_a_chunk, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb, - output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, - workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, - _math_sms); - - NVTE_CHECK_CUDA(cudaEventRecord(_start_comm, - (cudaStream_t)_stream_compute[i % _stream_compute.size()])); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0)); - - // Communication chunk. Uses MAX_SM at the last chunk - if (i == _num_splits - 1) { - _ub_comm->sms = UB_MAX_SM; - } - if (_ubuf.element_size() == 1) { - assert(_ubuf_scale_inv_initialized); - float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - D_type, fp8_type, - reducescatter2_userbuff_stridedoutput_fp8( - rs_output_ptr, d_scale_inv_ptr, _ub_reg, i * output_chunk_size, m_chunk, n, m, - _ub_comm, (cudaStream_t)_stream_comm);); - } else { - reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, i * output_chunk_size, - m_chunk, n, m, _ub_comm, - (cudaStream_t)_stream_comm); - } - rs_output_ptr += m_chunk * rs_output.element_size(); - input_a_chunk_ptr += input_a_chunk_size * B.element_size(); - output_buf_chunk_ptr += output_chunk_size * _ubuf.element_size(); - } - } - for (size_t i = 0; i < _stream_compute.size(); i++) { - NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i])); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0)); - } - _ub_comm->sms = ori_sms; - NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, (cudaStream_t)_stream_comm)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_comm, 0)); - at::cuda::setCurrentCUDAStream(stream_main); - - return; - } // split_overlap_rs - - void set_ubuf_scale_inv(const torch::Tensor &scale_inv) { - _ubuf_scale_inv = scale_inv; - _ubuf_scale_inv_initialized = true; - } - - bool is_fp8_ubuf() { return (_ubuf.element_size() == 1); } - /* - ** Helper function to copy input to _ubuf - */ - void copy_input_to_ubuf(torch::Tensor input, int comm_type) { - char *ubuf_ptr = reinterpret_cast(_ubuf.data_ptr()); - COMM_TYPE _comm_type = static_cast(comm_type); - if (_comm_type == COMM_TYPE::AG) { - if ((input.numel() * _tp_size) != _ubuf.numel() || - input.element_size() != _ubuf.element_size()) { - NVTE_ERROR("input and ubuf size do not match!"); - } - ubuf_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size(); - } else { - if (input.numel() != _ubuf.numel() || input.element_size() != _ubuf.element_size()) { - NVTE_ERROR("input and ubuf size do not match!"); - } - } - - at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); - 
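For reference, the split GEMM + reduce-scatter path above cycles its chunks over a small pool of compute streams and gives each stream its own slice of the shared workspace. The standalone sketch below, with illustrative values only, prints that round-robin schedule.

// Sketch only -- the chunk-to-stream and workspace-slice mapping used by split_overlap_rs.
#include <cstdio>

int main() {
  const int num_splits = 4;                       // illustrative GEMM split count
  const int num_streams = 3;                      // min(num_max_streams, num_splits)
  const size_t workspace_bytes = 3 * 1024 * 1024; // illustrative total workspace
  const size_t workspace_chunk = workspace_bytes / num_streams;
  for (int i = 0; i < num_splits; ++i) {
    const int stream_id = i % num_streams;        // chunk i runs on this compute stream
    std::printf("chunk %d -> stream %d, workspace offset %zu bytes\n",
                i, stream_id, stream_id * workspace_chunk);
  }
  return 0;
}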
NVTE_CHECK_CUDA(cudaEventRecord(_start_d2dcopy, (cudaStream_t)stream_main)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_d2dcopy, 0)); - NVTE_CHECK_CUDA(cudaMemcpyAsync(ubuf_ptr, input.data_ptr(), - input.numel() * input.element_size(), cudaMemcpyDeviceToDevice, - (cudaStream_t)_stream_comm)); - } - - torch::Tensor &get_ubuf_output(int comm_type) { - char *ubuf_wt_ptr = reinterpret_cast(_ubuf.data_ptr()); - COMM_TYPE _comm_type = static_cast(comm_type); - if (_comm_type != COMM_TYPE::AG && _comm_type != COMM_TYPE::RS) NVTE_ERROR("Invalid comm_type"); - if (_comm_type == COMM_TYPE::RS) - ubuf_wt_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size(); - int output_c_dim0 = (_comm_type == COMM_TYPE::AG) ? _ubuf.size(0) : _ubuf.size(0) / _tp_size; - int output_c_dim1 = _ubuf.size(1); - output_tensor = torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, _ubuf.options()); - return output_tensor; - } - - bool is_atomic_gemm() { return _atomic_gemm; } - bool is_p2p_overlap() { return false; } -}; // UbufCommOverlap - -struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { - int _tp_id; - int _tp_size; - int _ub_reg, _ub_reg2; - int _next_rank, _prev_rank, _rank, _rank_round_tp; - int _aggregate2; - int _math_sms; - int _self_chunk_id; - void *_ubuf_ptr; - torch::Tensor _ubuf; - torch::Tensor counter; - torch::Tensor _ubuf_scale_inv; - bool _ubuf_scale_inv_initialized; - std::vector _ubufs; - at::cuda::CUDAStream _stream_send = at::cuda::getStreamFromPool(true); - at::cuda::CUDAStream _stream_recv = at::cuda::getStreamFromPool(true); - std::vector _stream_compute; - cudaEvent_t _start_compute, _stop_compute, _start_comm, _stop_send, _stop_recv; - int _use_ce; - int _num_comm_sm; - int _cga_size; - bool _atomic_gemm; - - UbufP2PCommOverlap(torch::Tensor sample, int myrank, int numranks, int mylocal, int numlocal, - int mynode, int numnodes, int tp_size, int num_comm_sm, int comm_cga_size, - bool set_sm_margin, bool aggregate2, int num_max_streams, - bool is_reduce_scatter, bool atomic_gemm, bool use_ce, - UbufBootstrapCallbacks &callbacks) { - // Initialize userbuf communicator - if (!comm_created) { - if (myrank == 0) { - printf("!!! [UB] Create Userbuffers Communicator\n"); - } -#ifdef NVTE_UB_WITH_MPI - create_communicator_grouped2_mpi(&_ub_comm, 1, 1, tp_size, 1); -#else - create_communicator_grouped2( - &_ub_comm, myrank, numranks, mylocal, numlocal, mynode, numnodes, - std::bind(&UbufBootstrapCallbacks::ub_allgather, callbacks, _1, _2, _3, _4, _5), - std::bind(&UbufBootstrapCallbacks::ub_barrier, callbacks, _1), 1, 1, tp_size, 1); -#endif - comm_created = true; - } - _use_ce = use_ce; - _num_comm_sm = num_comm_sm; - _cga_size = comm_cga_size; - - // Create workspace tensor with userbuffer - int ubuf_bytes = sample.numel() * sample.element_size(); - int ubuf_chunk_bytes = ubuf_bytes / tp_size; - int num_ubuf_chunks = tp_size; - if (is_reduce_scatter) { - // GEMM + RS overlap: Allocate `2 x tp_size - 1` buffers to hold recieved GEMM chunk - // outputs for reduction at the end of the pipelining. - ubuf_bytes = static_cast(ubuf_bytes / tp_size * (tp_size * 2 - 1)); - num_ubuf_chunks = static_cast(tp_size * 2 - 1); - } - - _ub_reg = register_user_buffer_collective(reinterpret_cast(&_ubuf_ptr), ubuf_bytes, - _ub_comm, true); - _ubuf = torch::from_blob( - _ubuf_ptr, {sample.size(0) / tp_size * num_ubuf_chunks, sample.size(1)}, sample.options()); - if (_ub_comm->myrank == 0) { - printf("!!! 
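For reference, get_ubuf_output above exposes either the whole registered buffer (all-gather) or only this rank's 1/tp_size slice of it (reduce-scatter). The sketch below uses plain integers in place of tensors, with illustrative sizes, to show the shapes and byte offsets involved.

// Sketch only -- the AG vs. RS views of the shared userbuffer.
#include <cstdio>

int main() {
  const int tp_size = 4, tp_id = 2;
  const long rows = 8, cols = 16, elem_size = 2;  // e.g. a BF16 buffer
  const long numel = rows * cols;
  // All-gather output: the full buffer, starting at byte 0.
  std::printf("AG view: %ldx%ld at byte 0\n", rows, cols);
  // Reduce-scatter output: this rank's chunk, offset by its chunk index.
  const long rs_offset_bytes = numel / tp_size * tp_id * elem_size;
  std::printf("RS view: %ldx%ld at byte %ld\n", rows / tp_size, cols, rs_offset_bytes);
  return 0;
}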
[UBP2P] Register UBuf %d\n", _ub_reg); - } - - // Create tensor chunks for easy management - char *ubuf_byte_ptr = reinterpret_cast(_ubuf.data_ptr()); - for (int i = 0; i < num_ubuf_chunks; i++) { - auto ubuf_chunk = torch::from_blob(ubuf_byte_ptr, {sample.size(0) / tp_size, sample.size(1)}, - sample.options()); - _ubufs.push_back(std::move(ubuf_chunk)); - ubuf_byte_ptr += ubuf_chunk_bytes; - } - - at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); - for (int i = 0; i < std::min(num_max_streams, tp_size); i++) { - cudaStream_t stream; - cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, -1); - _stream_compute.push_back( - at::cuda::getStreamFromExternal(stream, stream_main.device_index())); - } - - // Set the number of SMs for GEMM with margin - cudaDeviceProp prop; - cudaGetDeviceProperties(&prop, 0); - _math_sms = (set_sm_margin) ? prop.multiProcessorCount - num_comm_sm : prop.multiProcessorCount; - _math_sms -= transformer_engine::getenv("NVTE_EXT_MARGIN_SM", 0); - - _tp_size = tp_size; - _aggregate2 = aggregate2; - - _rank = _ub_comm->myrank; - _tp_id = (_rank % _tp_size); - _rank_round_tp = (_rank / _tp_size) * _tp_size; - _next_rank = (_tp_size + _rank + 1) % _tp_size + _rank_round_tp; - _prev_rank = (_tp_size + _rank + -1) % _tp_size + _rank_round_tp; - _ubuf_scale_inv_initialized = false; - - _atomic_gemm = atomic_gemm; - _self_chunk_id = _tp_id; - if (_atomic_gemm) { - auto counter_options = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA); - counter = torch::zeros({_tp_size * 2}, counter_options); - counter.index_put_({Slice(None, _tp_size)}, 1); - - if (!is_reduce_scatter) { - const char *env_p = std::getenv("NVTE_AG_P2P_MULTI_ATOMIC"); - if (_rank == 0 && env_p != nullptr) { - if (env_p[0] == '1') { - _use_ce = 0; - _ub_comm->push = 1; - printf("!!userbuffers_sendrecv_multi_atomic_shuffle\n"); - } - } - _self_chunk_id = 0; - counter.index_put_({_self_chunk_id}, 0); - } - } - - // CUDA event creation - cudaEventCreateWithFlags(&_start_compute, 0); - cudaEventCreateWithFlags(&_stop_compute, 0); - cudaEventCreateWithFlags(&_start_comm, 0); - cudaEventCreateWithFlags(&_stop_send, 0); - cudaEventCreateWithFlags(&_stop_recv, 0); - } - - ~UbufP2PCommOverlap() { - cudaEventDestroy(_stop_recv); - cudaEventDestroy(_stop_send); - cudaEventDestroy(_start_comm); - cudaEventDestroy(_stop_compute); - cudaEventDestroy(_start_compute); - - for (size_t i = 0; i < _stream_compute.size(); i++) cudaStreamDestroy(_stream_compute[i]); - - if (comm_created) { -#ifdef NVTE_UB_WITH_MPI - destroy_communicator_mpi(_ub_comm); -#else - destroy_communicator(_ub_comm); -#endif - comm_created = false; - } - } - - /* - ** Split AllGather + AtomicGEMM using P2P communication - ** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is - *needed to have AG outputs - ** in each rank to be in the contiguous memory space after all ring exchange - *phases. 
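For reference, the ring used by the P2P overlap assumes that the ranks of one tensor-parallel group are contiguous, so the group base is the rank rounded down to a multiple of tp_size and the neighbours wrap around inside the group. The standalone sketch below reproduces the neighbour arithmetic from the constructor above with illustrative sizes.

// Sketch only -- next/previous ring neighbours within a tensor-parallel group.
#include <cstdio>

int main() {
  const int tp_size = 4;
  for (int rank = 0; rank < 8; ++rank) {          // two TP groups of 4 ranks each
    const int rank_round_tp = (rank / tp_size) * tp_size;
    const int next_rank = (tp_size + rank + 1) % tp_size + rank_round_tp;
    const int prev_rank = (tp_size + rank - 1) % tp_size + rank_round_tp;
    std::printf("rank %d: prev=%d next=%d\n", rank, prev_rank, next_rank);
  }
  return 0;
}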
- */ - torch::Tensor atomic_gemm_overlap_ag( - at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, - transformer_engine::DType A_type, bool transa, at::Tensor B, at::Tensor B_scale_inverse, - int64_t B_fp8_tensor, transformer_engine::DType B_type, bool transb, at::Tensor D, - at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias, - transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, - size_t workspaceSize, bool accumulate, bool use_split_accumulator, at::Tensor B_copy) { - int ori_sms = _ub_comm->sms; - _ub_comm->use_ce = _use_ce; - _ub_comm->sms = _num_comm_sm; - _ub_comm->cga_size = _cga_size; - // Get GEMM dimensions between TN and NN input layouts - const int m = (transa) ? A.size(0) : A.size(1); - const int n = _ubuf.size(0); - const int n_chunk = n / _tp_size; - - // Get communication and GEMM output chunk sizes - const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size(); - - // Create an GEMM output buffer with N+1 chunks in a contiguous memory - torch::Tensor D_buffer = torch::empty({n_chunk * (_tp_size + 1), m}, D.options()); - D = torch::from_blob(D_buffer.data_ptr(), {D.size(0), D.size(1)}, D.options()); - - // Get output and workspace data pointers - char *workspace_ptr = reinterpret_cast(workspace.data_ptr()); - int *counter_ptr = reinterpret_cast(counter.data_ptr()); - int workspace_size_chunk = workspaceSize / _stream_compute.size(); - - if (A_scale_inverse.numel()) A_scale_inverse = A_scale_inverse[A_fp8_tensor]; - - if (B_scale_inverse.numel()) B_scale_inverse = B_scale_inverse[B_fp8_tensor]; - - assert(pre_gelu_out.numel() == 0); - - // Catch up the default torch stream - at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); - NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0)); - - torch::Tensor workspace_chunk = - torch::from_blob(workspace_ptr, {workspace_size_chunk}, workspace.options()); - - for (int i = 0; i < _tp_size - 1; i++) { - // Set the userbuffer id. Buffer under send is the input for the current - // GEMM chunk The initial input chunk is stored _ubuf[rank]. 
This is to - // have the AG output in all ranks to be contiguous after the ring - // exchanges - int send_chunk_id = i; - int recv_chunk_id = i + 1; - int send_offset = comm_bytes * send_chunk_id; - int recv_offset = comm_bytes * recv_chunk_id; - - const char *env_p = std::getenv("NVTE_AG_P2P_MULTI_ATOMIC"); - if (env_p != nullptr && env_p[0] == '1') { - if (i == 0) { - _ub_comm->use_ce = 0; - userbuffers_sendrecv_multiatomic(_ub_reg, _ub_reg, comm_bytes, comm_bytes, comm_bytes, - _ub_comm, _next_rank, _prev_rank, _tp_size, counter_ptr, - true, (cudaStream_t)_stream_recv); - } - } else { - userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, - _next_rank, (cudaStream_t)_stream_recv); - userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, - _prev_rank, (cudaStream_t)_stream_recv); - producer(counter_ptr, recv_chunk_id, (cudaStream_t)_stream_recv); - } - if (i == 0) { - te_atomic_gemm(A, A_scale_inverse, A_type, transa, _ubuf, B_scale_inverse, B_type, transb, - D, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, - workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, - _math_sms, 0, _tp_size, false, counter); - } - } - - // Store the input activation for backprop - if (B_copy.numel() > 0) { - assert(B_copy.numel() == _ubufs[_self_chunk_id].numel()); - assert(B_copy.element_size() == _ubufs[_self_chunk_id].element_size()); - NVTE_CHECK_CUDA( - cudaMemcpyAsync(B_copy.data_ptr(), _ubufs[_self_chunk_id].data_ptr(), - _ubufs[_self_chunk_id].numel() * _ubufs[_self_chunk_id].element_size(), - cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_send)); - NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0)); - } - - // Reset atomic counters - consumer_batch(counter_ptr, 1, _tp_size, (cudaStream_t)stream_main); - - // Copy the first GEMM output chunk to the end chunk position of D_buffer - char *src_ptr = reinterpret_cast(D_buffer.data_ptr()); - NVTE_CHECK_CUDA(cudaMemcpyAsync(src_ptr + (D.numel() * D.element_size()), src_ptr, - n_chunk * m * D.element_size(), cudaMemcpyDeviceToDevice, - (cudaStream_t)stream_main)); - // Return the last N rows of D_buffer - _ub_comm->sms = ori_sms; - torch::Tensor D_return = D_buffer.narrow(0, n_chunk, n); - return D_return; - } // atomic_gemm_overlap_ag - - /* - ** Split AllGather + GEMM using P2P communication - ** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is - *needed to have AG outputs - ** in each rank to be in the contiguous memory space after all ring exchange - *phases. - */ - torch::Tensor split_overlap_ag(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, - transformer_engine::DType A_type, bool transa, at::Tensor B, - at::Tensor B_scale_inverse, int64_t B_fp8_tensor, - transformer_engine::DType B_type, bool transb, at::Tensor D, - at::Tensor D_scale, transformer_engine::DType D_type, - at::Tensor D_amax, at::Tensor bias, - transformer_engine::DType bias_type, at::Tensor pre_gelu_out, - bool grad, at::Tensor workspace, size_t workspaceSize, - bool accumulate, bool use_split_accumulator, at::Tensor B_copy) { - int ori_sms = _ub_comm->sms; - _ub_comm->use_ce = _use_ce; - _ub_comm->sms = _num_comm_sm; - _ub_comm->cga_size = _cga_size; - // Get GEMM dimensions between TN and NN input layouts - const int m = (transa) ? A.size(0) : A.size(1); - const int k = (transa) ? 
A.size(1) : A.size(0); - const int n_chunk = _ubufs[0].size(0); - - // Get communication and GEMM output chunk sizes - const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size(); - const bool do_gelu = pre_gelu_out.numel() > 0; - const int output_chunk_bytes = (n_chunk * m) * D.element_size(); - const int aux_chunk_bytes = do_gelu ? (n_chunk * m) * pre_gelu_out.element_size() : 0; - - // Get output and workspace data pointers - char *output_ptr = reinterpret_cast(D.data_ptr()); - char *pre_gelu_out_ptr = reinterpret_cast(pre_gelu_out.data_ptr()); - char *workspace_ptr = reinterpret_cast(workspace.data_ptr()); - int workspace_size_chunk = workspaceSize / _stream_compute.size(); - - if (A_scale_inverse.numel()) A_scale_inverse = A_scale_inverse[A_fp8_tensor]; - - if (B_scale_inverse.numel()) B_scale_inverse = B_scale_inverse[B_fp8_tensor]; - - at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); - NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main)); - - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0)); - for (size_t i = 0; i < _stream_compute.size(); i++) { - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[i], _start_compute, 0)); - } - if (_aggregate2) { - const int num_steps = _tp_size / 2; - char *input_b_ptr = reinterpret_cast(_ubuf.data_ptr()); - - // Initial 1X input chunk exchange between neighboring peers - int send_chunk_id = _tp_id; - int recv_chunk_id = (_tp_id % 2 == 0) ? _tp_id + 1 : _tp_id - 1; - int send_offset = comm_bytes * send_chunk_id; - int recv_offset = comm_bytes * recv_chunk_id; - int peer_rank = (_tp_id % 2 == 0) ? _next_rank : _prev_rank; - userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes, _ub_comm, peer_rank, - (cudaStream_t)_stream_send); - userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, peer_rank, - (cudaStream_t)_stream_recv); - NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _stop_recv, 0)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[0], _stop_recv, 0)); - - int local_rank_round2 = (_tp_id % 2 == 0) ? 
_tp_id : _tp_id - 1; - const int next_rank = (_tp_size + _tp_id + 2) % _tp_size + _rank_round_tp; - const int prev_rank = (_tp_size + _tp_id - 2) % _tp_size + _rank_round_tp; - - // Ring exchange of 2X inputs chunks - for (int i = 0; i < num_steps; i++) { - send_chunk_id = (_tp_size + local_rank_round2 - i * 2) % _tp_size; - recv_chunk_id = (_tp_size + local_rank_round2 - i * 2 - 2) % _tp_size; - send_offset = comm_bytes * send_chunk_id; - recv_offset = comm_bytes * recv_chunk_id; - - // GEMM - torch::Tensor input_b_chunk = - torch::from_blob(input_b_ptr + send_offset, {n_chunk * 2, k}, _ubuf.options()); - torch::Tensor output_chunk = torch::from_blob( - output_ptr + (send_chunk_id * output_chunk_bytes), {n_chunk * 2, m}, D.options()); - if (do_gelu) { - pre_gelu_out = torch::from_blob(pre_gelu_out_ptr + (send_chunk_id * aux_chunk_bytes), - {n_chunk * 2, m}, pre_gelu_out.options()); - } - torch::Tensor workspace_chunk = - torch::from_blob(workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk, - {workspace_size_chunk}, workspace.options()); - at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]); - te_gemm(A, A_scale_inverse, A_type, transa, input_b_chunk, B_scale_inverse, B_type, transb, - output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, - workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, - _math_sms); - - if (i < num_steps - 1) { - // P2P communication - userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes * 2, _ub_comm, - next_rank, (cudaStream_t)_stream_send); - userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes * 2, _ub_comm, - prev_rank, (cudaStream_t)_stream_recv); - NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _stop_recv, 0)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent( - (cudaStream_t)_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0)); - } else if (B_copy.numel() > 0) { - assert(B_copy.numel() == _ubufs[_tp_id].numel()); - assert(B_copy.element_size() == _ubufs[_tp_id].element_size()); - NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.data_ptr(), _ubufs[_tp_id].data_ptr(), - _ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(), - cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_send)); - } - } - } else { - for (int i = 0; i < _tp_size; i++) { - // Set the userbuffer id. Buffer under send is the input for the current - // GEMM chunk The initial input chunk is stored _ubuf[rank]. 
This is to - // have the AG output in all ranks to be contiguous after the ring - // exchanges - int send_chunk_id = (_tp_size + _tp_id - i) % _tp_size; - int recv_chunk_id = (_tp_size + _tp_id - i - 1) % _tp_size; - int send_offset = comm_bytes * send_chunk_id; - int recv_offset = comm_bytes * recv_chunk_id; - - // GEMM - torch::Tensor output_chunk = torch::from_blob( - output_ptr + (send_chunk_id * output_chunk_bytes), {n_chunk, m}, D.options()); - if (do_gelu) { - pre_gelu_out = torch::from_blob(pre_gelu_out_ptr + (send_chunk_id * aux_chunk_bytes), - {n_chunk, m}, pre_gelu_out.options()); - } - torch::Tensor workspace_chunk = - torch::from_blob(workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk, - {workspace_size_chunk}, workspace.options()); - at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]); - te_gemm(A, A_scale_inverse, A_type, transa, _ubufs[send_chunk_id], B_scale_inverse, B_type, - transb, output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, - workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, - _math_sms); - - if (i < _tp_size - 1) { - // P2P communication - userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes, _ub_comm, - _next_rank, (cudaStream_t)_stream_send); - userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, - _prev_rank, (cudaStream_t)_stream_recv); - NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _stop_recv, 0)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent( - (cudaStream_t)_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0)); - } else if (B_copy.numel() > 0) { - assert(B_copy.numel() == _ubufs[_tp_id].numel()); - assert(B_copy.element_size() == _ubufs[_tp_id].element_size()); - NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.data_ptr(), _ubufs[_tp_id].data_ptr(), - _ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(), - cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_send)); - } - } - } - for (size_t i = 0; i < _stream_compute.size(); i++) { - NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i])); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0)); - } - NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0)); - NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0)); - at::cuda::setCurrentCUDAStream(stream_main); - _ub_comm->sms = ori_sms; - - return D; - } // split_overlap_ag - - /* - ** Split ReduceScatter + GEMM using P2P communication - */ - void atomic_gemm_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, - transformer_engine::DType A_type, bool transa, at::Tensor B, - at::Tensor B_scale_inverse, int64_t B_fp8_tensor, - transformer_engine::DType B_type, bool transb, at::Tensor D, - at::Tensor D_scale, transformer_engine::DType D_type, - at::Tensor D_amax, at::Tensor bias, - transformer_engine::DType bias_type, at::Tensor pre_gelu_out, - bool grad, at::Tensor workspace, size_t workspaceSize, - bool accumulate, bool use_split_accumulator, at::Tensor rs_output) { - int ori_sms = _ub_comm->sms; - _ub_comm->use_ce = _use_ce; - _ub_comm->sms = _num_comm_sm; - _ub_comm->cga_size = _cga_size; - - // Get communication and GEMM input chunk sizes - const int 
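For reference, the all-gather ring above walks the chunk indices backwards from each rank's own chunk so that, after tp_size - 1 send/recv steps, every rank holds all chunks at their natural offsets. The standalone sketch below prints that schedule for illustrative sizes; the last step performs no receive, matching the loop above.

// Sketch only -- chunk indices processed and exchanged at each ring step of split_overlap_ag.
#include <cstdio>

int main() {
  const int tp_size = 4;  // illustrative tensor-parallel group size
  for (int tp_id = 0; tp_id < tp_size; ++tp_id) {
    std::printf("rank %d:", tp_id);
    for (int i = 0; i < tp_size; ++i) {
      // The GEMM consumes (and forwards) this chunk at step i...
      const int send_chunk_id = (tp_size + tp_id - i) % tp_size;
      std::printf("  step %d gemm/send=%d", i, send_chunk_id);
      if (i < tp_size - 1) {
        // ...while the next chunk arrives from the previous neighbour.
        const int recv_chunk_id = (tp_size + tp_id - i - 1) % tp_size;
        std::printf(" recv=%d", recv_chunk_id);
      }
    }
    std::printf("\n");
  }
  return 0;
}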
comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size(); - - // Get input and workspace data pointers - char *workspace_ptr = reinterpret_cast(workspace.data_ptr()); - int *counter_ptr = reinterpret_cast(counter.data_ptr()); - int workspace_size_chunk = workspaceSize / _stream_compute.size(); - - if (A_scale_inverse.numel()) A_scale_inverse = A_scale_inverse[A_fp8_tensor]; - - if (B_scale_inverse.numel()) B_scale_inverse = B_scale_inverse[B_fp8_tensor]; - - // Catch up the main stream - at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); - NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0)); - - // Atomic GEMM - // Process GEMM chunks in the order that AG+GEMM places the output chunks. - torch::Tensor workspace_chunk = - torch::from_blob(workspace_ptr, {workspace_size_chunk}, workspace.options()); - te_atomic_gemm(A, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb, _ubuf, - D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, workspace_chunk, - workspace_size_chunk, accumulate, use_split_accumulator, _math_sms, 0, _tp_size, - true, counter); - - // P2P communication chunk - for (int i = 1; i < _tp_size; i++) { - int send_chunk_id = i - 1; - int recv_chunk_id = send_chunk_id + _tp_size; - int send_offset = comm_bytes * send_chunk_id; - int recv_offset = comm_bytes * recv_chunk_id; - int send_rank = (_tp_size + _tp_id - i) % _tp_size + _rank_round_tp; - int recv_rank = (_tp_id + i) % _tp_size + _rank_round_tp; - - consumer(counter_ptr, send_chunk_id, (cudaStream_t)_stream_recv); - userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, send_rank, - (cudaStream_t)_stream_recv); - userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, recv_rank, - (cudaStream_t)_stream_recv); - } - NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0)); - - // Reduce GEMM output chunks - char *reduce_buf_ptr = reinterpret_cast(_ubufs[_tp_size - 1].data_ptr()); - if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) { - assert(_ubuf_scale_inv_initialized); - float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); - char *rs_output_ptr = reinterpret_cast(rs_output.data_ptr()); - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - D_type, fp8_type, - reduce_fp8_in_bf16_out(reduce_buf_ptr, rs_output_ptr, d_scale_inv_ptr, _tp_size, - _ubufs[0].numel(), (cudaStream_t)stream_main);); - } else { - torch::Tensor reduce_buf = torch::from_blob( - reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options()); - torch::sum_out(rs_output, reduce_buf, 0); - } - _ub_comm->sms = ori_sms; - } - - /* - ** Split ReduceScatter + GEMM using P2P communication - */ - void split_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, - transformer_engine::DType A_type, bool transa, at::Tensor B, - at::Tensor B_scale_inverse, int64_t B_fp8_tensor, - transformer_engine::DType B_type, bool transb, at::Tensor D, - at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, - at::Tensor bias, transformer_engine::DType bias_type, - at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, - size_t workspaceSize, bool accumulate, bool use_split_accumulator, - at::Tensor rs_output) { - int ori_sms = _ub_comm->sms; - _ub_comm->use_ce = _use_ce; - _ub_comm->sms = 
_num_comm_sm; - _ub_comm->cga_size = _cga_size; - int k = A.size(1); - int n = B.size(0); - - // Get communication and GEMM input chunk sizes - int n_chunk = n / _tp_size; - const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size(); - const int input_b_chunk_bytes = n_chunk * k * B.element_size(); - - // Get input and workspace data pointers - char *input_b_ptr = reinterpret_cast(B.data_ptr()); - char *workspace_ptr = reinterpret_cast(workspace.data_ptr()); - int workspace_size_chunk = workspaceSize / _stream_compute.size(); - - if (A_scale_inverse.numel()) A_scale_inverse = A_scale_inverse[A_fp8_tensor]; - - if (B_scale_inverse.numel()) B_scale_inverse = B_scale_inverse[B_fp8_tensor]; - - // Catch up the main stream - at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); - NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0)); - for (size_t i = 0; i < _stream_compute.size(); i++) { - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[i], _start_compute, 0)); - } - - // GEMM and send/recv chunks - for (int i = 0; i < _tp_size; i++) { - // GEMM chunk - int input_b_chunk_id = (_tp_id + i + 1) % _tp_size; - char *input_b_chunk_ptr = input_b_ptr + (input_b_chunk_id * input_b_chunk_bytes); - torch::Tensor input_b_chunk = torch::from_blob(input_b_chunk_ptr, {n_chunk, k}, B.options()); - // Store the last GEMM chunk output to the recieve buffer. - torch::Tensor workspace_chunk = - torch::from_blob(workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk, - {workspace_size_chunk}, workspace.options()); - at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]); - te_gemm(A, A_scale_inverse, A_type, transa, input_b_chunk, B_scale_inverse, B_type, transb, - _ubufs[i], D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, - workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, _math_sms); - - if (i > 0) { - // P2P communication chunk - int send_offset = comm_bytes * (i - 1); - int recv_offset = comm_bytes * (i - 1 + _tp_size); - int send_rank = (_tp_id + i) % _tp_size + _rank_round_tp; - int recv_rank = (_tp_size + _tp_id - i) % _tp_size + _rank_round_tp; - NVTE_CHECK_CUDA(cudaEventRecord( - _start_comm, (cudaStream_t)_stream_compute[(i - 1) % _stream_compute.size()])); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_comm, 0)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_comm, 0)); - userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, - send_rank, (cudaStream_t)_stream_send); - userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, - recv_rank, (cudaStream_t)_stream_recv); - } - } - at::cuda::setCurrentCUDAStream(stream_main); - for (size_t i = 0; i < _stream_compute.size(); i++) { - NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i])); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0)); - } - NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0)); - NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0)); - - // Reduce GEMM output chunks - char 
*reduce_buf_ptr = reinterpret_cast(_ubufs[_tp_size - 1].data_ptr()); - if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) { - assert(_ubuf_scale_inv_initialized); - float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); - char *rs_output_ptr = reinterpret_cast(rs_output.data_ptr()); - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - D_type, fp8_type, - reduce_fp8_in_bf16_out(reduce_buf_ptr, rs_output_ptr, d_scale_inv_ptr, _tp_size, - _ubufs[0].numel(), (cudaStream_t)stream_main);); - } else { - torch::Tensor reduce_buf = torch::from_blob( - reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options()); - torch::sum_out(rs_output, reduce_buf, 0); - } - _ub_comm->sms = ori_sms; - } - - /* - ** Copy input to _ubufs[0] - */ - void copy_input_to_ubuf(torch::Tensor input, bool chunk) { - at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); - if (chunk) { - // Copy input to the target ubuf chunk by rank offset - if (input.numel() != _ubufs[0].numel() || input.element_size() != _ubufs[0].element_size()) { - NVTE_ERROR("input and ubuf size do not match!"); - } - NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubufs[_tp_id].data_ptr(), input.data_ptr(), - input.numel() * input.element_size(), - cudaMemcpyDeviceToDevice, (cudaStream_t)stream_main)); - } else { - if (input.numel() != _ubuf.numel() || input.element_size() != _ubuf.element_size()) { - NVTE_ERROR("input and ubuf size do not match!"); - } - NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubuf.data_ptr(), input.data_ptr(), - input.numel() * input.element_size(), - cudaMemcpyDeviceToDevice, (cudaStream_t)stream_main)); - } - } - - torch::Tensor get_ubuf_output(int comm_type) { - char *ubuf_wt_ptr = reinterpret_cast(_ubuf.data_ptr()); - COMM_TYPE _comm_type = static_cast(comm_type); - if (_comm_type != COMM_TYPE::AG && _comm_type != COMM_TYPE::RS) NVTE_ERROR("Invalid comm_type"); - if (_comm_type == COMM_TYPE::RS) - ubuf_wt_ptr += _ubuf.numel() / _tp_size * _self_chunk_id * _ubuf.element_size(); - int output_c_dim0 = (_comm_type == COMM_TYPE::AG) ? 
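For reference, the chunks received in the P2P reduce-scatter path are partial results for the same output rows, so they are reduced by an element-wise sum (torch::sum_out over the chunk dimension above, or a fused FP8-to-BF16 reduction when the buffer is FP8). The standalone sketch below shows the plain-sum case with illustrative sizes.

// Sketch only -- summing tp_size partial chunks into the reduce-scatter output.
#include <cstdio>
#include <vector>

int main() {
  const int tp_size = 4, chunk_elems = 8;
  std::vector<float> reduce_buf(tp_size * chunk_elems, 1.0f);  // per-rank partial results
  std::vector<float> rs_output(chunk_elems, 0.0f);
  for (int r = 0; r < tp_size; ++r)
    for (int e = 0; e < chunk_elems; ++e)
      rs_output[e] += reduce_buf[r * chunk_elems + e];
  std::printf("rs_output[0] = %.1f\n", rs_output[0]);  // 4.0 with all-ones partials
  return 0;
}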
_ubuf.size(0) : _ubuf.size(0) / _tp_size; - int output_c_dim1 = _ubuf.size(1); - return torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, _ubuf.options()); - } - - void set_ubuf_scale_inv(const torch::Tensor &scale_inv) { - _ubuf_scale_inv = scale_inv; - _ubuf_scale_inv_initialized = true; - } - - bool is_fp8_ubuf() { return (_ubuf.element_size() == 1); } - bool is_atomic_gemm() { return _atomic_gemm; } - bool is_p2p_overlap() { return true; } -}; // UbufP2PCommOverlap - -} // namespace ubuf - -#endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_COMM_GEMM_OVERLAP_H_ diff --git a/transformer_engine/pytorch/csrc/common.cu b/transformer_engine/pytorch/csrc/common.cpp similarity index 99% rename from transformer_engine/pytorch/csrc/common.cu rename to transformer_engine/pytorch/csrc/common.cpp index 2d8e602c5b..2ac190863c 100644 --- a/transformer_engine/pytorch/csrc/common.cu +++ b/transformer_engine/pytorch/csrc/common.cpp @@ -5,6 +5,7 @@ ************************************************************************/ #include "common.h" + #include "transformer_engine/transformer_engine.h" transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid, diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 04a1193a71..175a7b0e90 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -24,6 +24,7 @@ #include #include #include +#include #include #include #include @@ -37,12 +38,14 @@ #include #include +#include #include #include #include #include #include #include +#include #include #include "common/util/logging.h" diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index c30e583178..3b49ece4a3 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -7,6 +7,8 @@ #ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_ #define TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_ +#include + #include "common.h" #include "common/common.h" @@ -208,6 +210,12 @@ std::vector fused_cast_transpose_bgrad_dgelu(at::Tensor grad_output, int scale_offset = 0, int amax_offset = 0, int scale_inv_offset = 0); +void fused_dswiglu_cast_transpose(at::Tensor grad_output, at::Tensor input, at::Tensor grad_input, + at::Tensor grad_input_transpose, at::Tensor scale, + at::Tensor amax, at::Tensor scale_inv, + transformer_engine::DType otype, int scale_offset = 0, + int amax_offset = 0, int scale_inv_offset = 0); + void fused_multi_cast_transpose(std::vector input_list, std::vector scale_list, std::vector cast_output_list, @@ -504,4 +512,184 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output, std::vector input_row_list, std::vector padded_input_row_list); +/*************************************************************************************************** + * Comm+GEMM Overlap Wrappers + **************************************************************************************************/ + +class CommOverlapHelper : torch::CustomClassHolder { + private: + bool initialized{false}; + bool backend_is_nccl{false}; + std::map pgs; + + public: + int myrank = -1; + int numranks = -1; + int mylocal = -1; + int numlocal = -1; + int mynode = -1; + int numnodes = -1; + + CommOverlapHelper(); + + CommOverlapHelper(c10d::ProcessGroup *world_group, + std::optional intra_node_group, + std::optional inter_node_group); + + ~CommOverlapHelper(); + + void ub_allgather(void *globaldata, size_t globalbytes, void *localdata, 
size_t localbytes, + ExtComm comm); + + void ub_barrier(ExtComm comm); +}; + +class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOverlapBase { + private: + torch::Tensor _ubuf_torch; + torch::Tensor _ubuf_counter; + + public: + CommOverlap(const std::vector &buffer_shape, at::ScalarType buffer_dtype, + CommOverlapHelper *helper, int tp_size, int num_splits = 3, + int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2, + int num_comm_sm = 16, bool set_sm_margin = true, bool atomic_gemm = false); + + void set_ubuf_scale_inv(torch::Tensor scale_inv) { + assert(scale_inv.numel()); + assert(scale_inv.scalar_type() == torch::kFloat32); + transformer_engine::CommOverlapBase::set_ubuf_scale_inv( + reinterpret_cast(scale_inv.data_ptr())); + } + + void copy_input_to_ubuf(torch::Tensor input, int comm_type); + + torch::Tensor get_ubuf_output(int comm_type); + + /* + ** Bulk GEMM + COMM + ** This function assumes the communication input is pre-copied to _ubuf + */ + std::vector bulk_overlap( + at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, + transformer_engine::DType A_type, bool transa, at::Tensor B, at::Tensor B_scale_inverse, + int64_t B_fp8_tensor, transformer_engine::DType B_type, bool transb, at::Tensor D, + at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias, + transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, + size_t workspaceSize, bool accumulate, bool use_split_accumulator, + transformer_engine::CommOverlapType comm_type, at::Tensor rs_output); + + /* + ** Split FPROP GEMM + ReduceScatter + */ + void atomic_gemm_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, + transformer_engine::DType A_type, bool transa, at::Tensor B, + at::Tensor B_scale_inverse, int64_t B_fp8_tensor, + transformer_engine::DType B_type, bool transb, at::Tensor D, + at::Tensor D_scale, transformer_engine::DType D_type, + at::Tensor D_amax, at::Tensor bias, + transformer_engine::DType bias_type, at::Tensor pre_gelu_out, + bool grad, at::Tensor workspace, size_t workspaceSize, + bool accumulate, bool use_split_accumulator, bool gemm_overlap, + at::Tensor rs_output); + + /* + ** Split FPROP GEMM + ReduceScatter + */ + void split_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, + transformer_engine::DType A_type, bool transa, at::Tensor B, + at::Tensor B_scale_inverse, int64_t B_fp8_tensor, + transformer_engine::DType B_type, bool transb, at::Tensor D, + at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, + at::Tensor bias, transformer_engine::DType bias_type, + at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, + size_t workspaceSize, bool accumulate, bool use_split_accumulator, + bool gemm_overlap, at::Tensor rs_output); +}; // CommOverlap + +class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::CommOverlapP2PBase { + private: + torch::Tensor _ubuf_torch; + torch::Tensor _ubuf_counter; + + public: + CommOverlapP2P(const std::vector &buffer_shape, at::ScalarType buffer_dtype, + CommOverlapHelper *helper, int tp_size, + transformer_engine::CommOverlapType comm_type, + int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2, + int num_comm_sm = 3, bool set_sm_margin = true, bool atomic_gemm = false, + bool use_ce = true, bool aggregate = false); + + void set_ubuf_scale_inv(torch::Tensor scale_inv) { + assert(scale_inv.numel()); + assert(scale_inv.scalar_type() == 
torch::kFloat32); + transformer_engine::CommOverlapP2PBase::set_ubuf_scale_inv( + reinterpret_cast(scale_inv.data_ptr())); + } + + void copy_input_to_ubuf(torch::Tensor input, bool chunk); + + torch::Tensor get_ubuf_output(int comm_type); + + /* + ** Split AllGather + AtomicGEMM using P2P communication + ** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is + *needed to have AG outputs + ** in each rank to be in the contiguous memory space after all ring exchange + *phases. + */ + void atomic_gemm_overlap_ag(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, + transformer_engine::DType A_type, bool transa, at::Tensor B, + at::Tensor B_scale_inverse, int64_t B_fp8_tensor, + transformer_engine::DType B_type, bool transb, at::Tensor D, + at::Tensor D_scale, transformer_engine::DType D_type, + at::Tensor D_amax, at::Tensor bias, + transformer_engine::DType bias_type, at::Tensor pre_gelu_out, + bool grad, at::Tensor workspace, size_t workspaceSize, + bool accumulate, bool use_split_accumulator, at::Tensor B_copy); + + /* + ** Split AllGather + GEMM using P2P communication + ** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is + *needed to have AG outputs + ** in each rank to be in the contiguous memory space after all ring exchange + *phases. + */ + void split_overlap_ag(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, + transformer_engine::DType A_type, bool transa, at::Tensor B, + at::Tensor B_scale_inverse, int64_t B_fp8_tensor, + transformer_engine::DType B_type, bool transb, at::Tensor D, + at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, + at::Tensor bias, transformer_engine::DType bias_type, + at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, + size_t workspaceSize, bool accumulate, bool use_split_accumulator, + at::Tensor B_copy); + + /* + ** Split ReduceScatter + GEMM using P2P communication + */ + void atomic_gemm_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, + transformer_engine::DType A_type, bool transa, at::Tensor B, + at::Tensor B_scale_inverse, int64_t B_fp8_tensor, + transformer_engine::DType B_type, bool transb, at::Tensor D, + at::Tensor D_scale, transformer_engine::DType D_type, + at::Tensor D_amax, at::Tensor bias, + transformer_engine::DType bias_type, at::Tensor pre_gelu_out, + bool grad, at::Tensor workspace, size_t workspaceSize, + bool accumulate, bool use_split_accumulator, at::Tensor rs_output); + + /* + ** Split ReduceScatter + GEMM using P2P communication + */ + void split_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, + transformer_engine::DType A_type, bool transa, at::Tensor B, + at::Tensor B_scale_inverse, int64_t B_fp8_tensor, + transformer_engine::DType B_type, bool transb, at::Tensor D, + at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, + at::Tensor bias, transformer_engine::DType bias_type, + at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, + size_t workspaceSize, bool accumulate, bool use_split_accumulator, + at::Tensor rs_output); +}; // CommOverlapP2P + #endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_ diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cu b/transformer_engine/pytorch/csrc/extensions/activation.cpp similarity index 100% rename from transformer_engine/pytorch/csrc/extensions/activation.cu rename to transformer_engine/pytorch/csrc/extensions/activation.cpp diff --git 
a/transformer_engine/pytorch/csrc/extensions/apply_rope.cu b/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp similarity index 100% rename from transformer_engine/pytorch/csrc/extensions/apply_rope.cu rename to transformer_engine/pytorch/csrc/extensions/apply_rope.cpp diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cu b/transformer_engine/pytorch/csrc/extensions/cast.cpp similarity index 100% rename from transformer_engine/pytorch/csrc/extensions/cast.cu rename to transformer_engine/pytorch/csrc/extensions/cast.cpp diff --git a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp new file mode 100644 index 0000000000..d212d13516 --- /dev/null +++ b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp @@ -0,0 +1,480 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "../extensions.h" + +#define HALF_BYTES 2 +#define UB_MAX_SM 32 + +using namespace torch::indexing; +using namespace std::placeholders; + +namespace te = transformer_engine; + +#define MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inv, A_fp8_index, A_type, B, B_scale_inv, \ + B_fp8_index, B_type, D, D_amax, D_scale, D_type, bias, \ + bias_type, pre_gelu_out, workspace) \ + A = A.contiguous(); \ + void *A_scale_inv_ptr = nullptr; \ + if (te::is_fp8_dtype(A_type)) { \ + assert(A_scale_inv.numel()); \ + A_scale_inv_ptr = A_scale_inv[A_fp8_index].data_ptr(); \ + } \ + auto A_ = makeTransformerEngineTensor( \ + A.data_ptr(), {static_cast(A.size(0)), static_cast(A.size(1))}, A_type, \ + nullptr, nullptr, A_scale_inv_ptr); \ + B = B.contiguous(); \ + void *B_scale_inv_ptr = nullptr; \ + if (te::is_fp8_dtype(B_type)) { \ + assert(B_scale_inv.numel()); \ + B_scale_inv_ptr = B_scale_inv[B_fp8_index].data_ptr(); \ + } \ + auto B_ = makeTransformerEngineTensor( \ + B.data_ptr(), {static_cast(B.size(0)), static_cast(B.size(1))}, B_type, \ + nullptr, nullptr, B_scale_inv_ptr); \ + void *D_amax_ptr = nullptr; \ + void *D_scale_ptr = nullptr; \ + if (te::is_fp8_dtype(D_type)) { \ + assert(D_amax.numel()); \ + D_amax_ptr = D_amax.data_ptr(); \ + assert(D_scale.numel()); \ + D_scale_ptr = D_scale.data_ptr(); \ + } \ + auto D_ = makeTransformerEngineTensor( \ + D.data_ptr(), {static_cast(D.size(0)), static_cast(D.size(1))}, D_type, \ + D_amax_ptr, D_scale_ptr, nullptr); \ + auto bias_ = makeTransformerEngineTensor( \ + bias.data_ptr(), std::vector{static_cast(bias.size(0))}, bias_type); \ + const auto gelu_shape = (pre_gelu_out.data_ptr() == nullptr) \ + ? 
std::vector{static_cast(pre_gelu_out.size(0))} \ + : std::vector{static_cast(pre_gelu_out.size(0)), \ + static_cast(pre_gelu_out.size(1))}; \ + auto pre_gelu_out_ = makeTransformerEngineTensor( \ + pre_gelu_out.data_ptr(), gelu_shape, GetTransformerEngineDType(pre_gelu_out.scalar_type())); \ + auto workspace_ = makeTransformerEngineTensor( \ + workspace.data_ptr(), std::vector{static_cast(workspace.size(0))}, \ + te::DType::kByte); + +/*************************************************************************************************** + * CommOverlapHelper + **************************************************************************************************/ + +CommOverlapHelper::CommOverlapHelper() { +#ifndef NVTE_UB_WITH_MPI + NVTE_ERROR("Internal TE error: Dummy CommOverlapHelper init without NVTE_UB_WITH_MPI=1!"); +#endif +} // empty constructor for NVTE_UB_WITH_MPI=1 + +CommOverlapHelper::CommOverlapHelper(c10d::ProcessGroup *world_group, + std::optional intra_domain_group, + std::optional inter_domain_group) { +#ifndef NVTE_UB_WITH_MPI + pgs.insert({"world", world_group}); + myrank = pgs["world"]->getRank(); + numranks = pgs["world"]->getSize(); + c10d::ProcessGroup::BackendType backend = pgs["world"]->getBackendType(); + backend_is_nccl = (backend == c10d::ProcessGroup::BackendType::NCCL); + + if (intra_domain_group.has_value()) { + // Get local rank on node and number of local ranks + NVTE_CHECK(intra_domain_group.value()->getBackendType() == backend, + "Internal TE error: Intra-node group must be on the same backend (%s) as the world ", + "group!", pgs["world"]->getBackendName()); + pgs.insert({"intra", intra_domain_group.value()}); + mylocal = pgs["intra"]->getRank(); + numlocal = pgs["intra"]->getSize(); + + if (numlocal == numranks) { + // Intra-node group is same as the world group so there can only be 1 node + NVTE_CHECK( + mylocal == myrank, + "Internal TE error: Local rank must be equal to global rank when intra-node group size ", + "is equal to the world group size!"); + mynode = 0; + numnodes = 1; + } else { + // Intra-node group is different than the world group so there must be multiple nodes + NVTE_CHECK( + inter_domain_group.has_value(), + "Internal TE error: Inter-node group cannot be `None` when intra-node group is not ", + "identical to the world_group!"); + + // Get node ID and number of nodes + NVTE_CHECK( + inter_domain_group.value()->getBackendType() == backend, + "Internal TE error: Inter-node group must be on the same backend (%s) as the world ", + "group!", pgs["world"]->getBackendName()); + pgs.insert({"inter", inter_domain_group.value()}); + mynode = pgs["inter"]->getRank(); + numnodes = pgs["inter"]->getSize(); + } + } else { + // Intra-node group is not set so we assume there is only 1 node + mylocal = myrank; + numlocal = numranks; + pgs.insert({"intra", world_group}); + + mynode = 0; + numnodes = 1; + } + + initialized = true; +#else + NVTE_ERROR("Internal TE error: CommOverlapHelper cannot be initialized with valid PyTorch ", + "distributed process groups when TE is compiled with NVTE_UB_WITH_MPI=1!"); +#endif +} + +CommOverlapHelper::~CommOverlapHelper() { +#ifndef NVTE_UB_WITH_MPI + for (auto &pg : pgs) pg.second = nullptr; + backend_is_nccl = false; + initialized = false; +#endif +} + +void CommOverlapHelper::ub_allgather(void *globaldata, size_t globalbytes, void *localdata, + size_t localbytes, ExtComm group) { +#ifndef NVTE_UB_WITH_MPI + NVTE_CHECK(initialized, "Internal TE error: tex.CommOverlapHelper() is not initialized ", + "with valid process 
groups!"); + + auto localtensor = + torch::from_blob(localdata, {static_cast(localbytes / sizeof(uint8_t))}, + at::device(torch::kCPU).dtype(torch::kUInt8)); + auto localtmp = (backend_is_nccl) ? localtensor.cuda() : localtensor; + auto globaltensor = + torch::from_blob(globaldata, {static_cast(globalbytes / sizeof(uint8_t))}, + at::device(torch::kCPU).dtype(torch::kUInt8)); + auto globaltmp = (backend_is_nccl) ? globaltensor.cuda() : globaltensor; + + std::vector> globalchunks = {globaltmp.chunk(pgs[group]->getSize())}; + std::vector localchunk = {localtmp}; + auto work = pgs[group]->allgather(globalchunks, localchunk); + work->wait(); + + if (backend_is_nccl) { + globaltensor.copy_(globaltmp.cpu()); + globaltmp = torch::Tensor(); + localtmp = torch::Tensor(); + } +#else + NVTE_ERROR("Internal TE error: CommOverlapHelper::ub_allgather is a no-op when TE is compiled ", + "with NVTE_UB_WITH_MPI=1!"); +#endif +} + +void CommOverlapHelper::ub_barrier(ExtComm group) { +#ifndef NVTE_UB_WITH_MPI + NVTE_CHECK(initialized, "Internal TE error: tex.CommOverlapHelper() is not initialized ", + "with valid process groups!"); + auto work = pgs[group]->barrier(); + work->wait(); +#else + NVTE_ERROR("Internal TE error: CommOverlapHelper::ub_barrier is a no-op when TE is compiled ", + "with NVTE_UB_WITH_MPI=1!"); +#endif +} + +/*************************************************************************************************** + * CommOverlap + **************************************************************************************************/ + +CommOverlap::CommOverlap(const std::vector &buffer_shape, at::ScalarType buffer_dtype, + CommOverlapHelper *helper, int tp_size, int num_splits, + int num_max_streams, int comm_cga_size, int num_comm_sm, + bool set_sm_margin, bool atomic_gemm) + : te::CommOverlapBase(buffer_shape, GetTransformerEngineDType(buffer_dtype), helper->myrank, + helper->numranks, helper->mylocal, helper->numlocal, helper->mynode, + helper->numnodes, tp_size, + std::bind(&CommOverlapHelper::ub_allgather, helper, _1, _2, _3, _4, _5), + std::bind(&CommOverlapHelper::ub_barrier, helper, _1), num_splits, + num_max_streams, comm_cga_size, num_comm_sm, set_sm_margin, atomic_gemm) { + // Even though we never use these PyTorch tensor wrappers directly, they're still necessary to + // for PyTorch to factor externally allocated memory into its memory pool and garbage collection + // threshold calculation. 
+ _ubuf_torch = torch::from_blob( + _ubuf.dptr(), {static_cast(_ubuf.size(0)), static_cast(_ubuf.size(1))}, + at::device(torch::kCUDA).dtype(buffer_dtype)); + if (_atomic_gemm) { + _ubuf_counter = torch::from_blob(_counter.dptr(), {static_cast(_num_splits * 2)}, + at::device(torch::kCUDA).dtype(torch::kInt32)); + } +} + +/* +** Bulk GEMM + COMM +** This function assumes the communication input is pre-copied to _ubuf +*/ +std::vector CommOverlap::bulk_overlap( + at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, te::DType A_type, bool transa, + at::Tensor B, at::Tensor B_scale_inverse, int64_t B_fp8_tensor, te::DType B_type, bool transb, + at::Tensor D, at::Tensor D_scale, te::DType D_type, at::Tensor D_amax, at::Tensor bias, + te::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, + size_t workspaceSize, bool accumulate, bool use_split_accumulator, + te::CommOverlapType comm_type, at::Tensor rs_output) { + MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, B, B_scale_inverse, + B_fp8_tensor, B_type, D, D_amax, D_scale, D_type, bias, bias_type, + pre_gelu_out, workspace) + + auto rs_out_ = makeTransformerEngineTensor(rs_output); + cudaStream_t stream_main = static_cast(at::cuda::getCurrentCUDAStream()); + te::CommOverlapBase::bulk_overlap(A_, transa, B_, transb, D_, bias_, pre_gelu_out_, workspace_, + grad, accumulate, use_split_accumulator, comm_type, rs_out_, + stream_main); + + // Get the current userbuf offset + char *ubuf_wt_ptr = reinterpret_cast(_ubuf.dptr()); + if (comm_type == te::CommOverlapType::RS) { + ubuf_wt_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size(); + } + + // Generate output tensor from userbuf data pointer + int output_c_dim0 = + (comm_type == te::CommOverlapType::AG) ? 
_ubuf.size(0) : _ubuf.size(0) / _tp_size; + int output_c_dim1 = _ubuf.size(1); + auto output_tensor = + torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, _ubuf_torch.options()); + + return {D, output_tensor}; +} // CommOverlap::bulk_overlap + +/* +** Split FPROP GEMM + ReduceScatter +*/ +void CommOverlap::atomic_gemm_overlap_rs( + at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, te::DType A_type, bool transa, + at::Tensor B, at::Tensor B_scale_inverse, int64_t B_fp8_tensor, te::DType B_type, bool transb, + at::Tensor D, at::Tensor D_scale, te::DType D_type, at::Tensor D_amax, at::Tensor bias, + te::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, + size_t workspaceSize, bool accumulate, bool use_split_accumulator, bool gemm_overlap, + at::Tensor rs_output) { + MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, B, B_scale_inverse, + B_fp8_tensor, B_type, D, D_amax, D_scale, D_type, bias, bias_type, + pre_gelu_out, workspace) + + auto rs_out_ = makeTransformerEngineTensor(rs_output); + cudaStream_t stream_main = static_cast(at::cuda::getCurrentCUDAStream()); + te::CommOverlapBase::atomic_gemm_overlap_rs(A_, transa, B_, transb, D_, bias_, pre_gelu_out_, + workspace_, grad, accumulate, use_split_accumulator, + gemm_overlap, rs_out_, stream_main); +} // CommOverlap::split_overlap_rs + +/* +** Split FPROP GEMM + ReduceScatter +*/ +void CommOverlap::split_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, + te::DType A_type, bool transa, at::Tensor B, + at::Tensor B_scale_inverse, int64_t B_fp8_tensor, + te::DType B_type, bool transb, at::Tensor D, at::Tensor D_scale, + te::DType D_type, at::Tensor D_amax, at::Tensor bias, + te::DType bias_type, at::Tensor pre_gelu_out, bool grad, + at::Tensor workspace, size_t workspaceSize, bool accumulate, + bool use_split_accumulator, bool gemm_overlap, + at::Tensor rs_output) { + MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, B, B_scale_inverse, + B_fp8_tensor, B_type, D, D_amax, D_scale, D_type, bias, bias_type, + pre_gelu_out, workspace) + + auto rs_out_ = makeTransformerEngineTensor(rs_output); + cudaStream_t stream_main = static_cast(at::cuda::getCurrentCUDAStream()); + te::CommOverlapBase::split_overlap_rs(A_, transa, B_, transb, D_, bias_, pre_gelu_out_, + workspace_, grad, accumulate, use_split_accumulator, + gemm_overlap, rs_out_, stream_main); +} // CommOverlap::split_overlap_rs + +/* +** Helper function to copy input to _ubuf +*/ +void CommOverlap::copy_input_to_ubuf(torch::Tensor input, int comm_type) { + char *ubuf_ptr = reinterpret_cast(_ubuf.dptr()); + te::CommOverlapType _comm_type = static_cast(comm_type); + if (_comm_type == te::CommOverlapType::AG) { + if ((input.numel() * _tp_size) != (int64_t)_ubuf.numel() || + input.element_size() != (int64_t)_ubuf.element_size()) { + NVTE_ERROR("input and ubuf size do not match!"); + } + ubuf_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size(); + } else { + if (input.numel() != (int64_t)_ubuf.numel() || + input.element_size() != (int64_t)_ubuf.element_size()) { + NVTE_ERROR("input and ubuf size do not match!"); + } + } + + at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); + NVTE_CHECK_CUDA(cudaEventRecord(_start_d2dcopy, (cudaStream_t)stream_main)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_d2dcopy, 0)); + NVTE_CHECK_CUDA(cudaMemcpyAsync(ubuf_ptr, input.data_ptr(), input.numel() * input.element_size(), + 
cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_comm)); +} + +torch::Tensor CommOverlap::get_ubuf_output(int comm_type) { + char *ubuf_wt_ptr = reinterpret_cast(_ubuf.dptr()); + te::CommOverlapType _comm_type = static_cast(comm_type); + if (_comm_type != te::CommOverlapType::AG && _comm_type != te::CommOverlapType::RS) + NVTE_ERROR("Invalid comm_type"); + if (_comm_type == te::CommOverlapType::RS) + ubuf_wt_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size(); + int output_c_dim0 = + (_comm_type == te::CommOverlapType::AG) ? _ubuf.size(0) : _ubuf.size(0) / _tp_size; + int output_c_dim1 = _ubuf.size(1); + return torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, + torch::device(torch::kCUDA).dtype(GetATenDType(_ubuf.dtype()))); +} + +/*************************************************************************************************** + * CommOverlapP2P + **************************************************************************************************/ + +CommOverlapP2P::CommOverlapP2P(const std::vector &buffer_shape, at::ScalarType buffer_dtype, + CommOverlapHelper *helper, int tp_size, + te::CommOverlapType comm_type, int num_max_streams, + int comm_cga_size, int num_comm_sm, bool set_sm_margin, + bool atomic_gemm, bool use_ce, bool aggregate) + : te::CommOverlapP2PBase( + buffer_shape, GetTransformerEngineDType(buffer_dtype), helper->myrank, helper->numranks, + helper->mylocal, helper->numlocal, helper->mynode, helper->numnodes, tp_size, + std::bind(&CommOverlapHelper::ub_allgather, helper, _1, _2, _3, _4, _5), + std::bind(&CommOverlapHelper::ub_barrier, helper, _1), comm_type, num_max_streams, + comm_cga_size, num_comm_sm, set_sm_margin, use_ce, atomic_gemm, aggregate) { + // Even though we never use these PyTorch tensor wrappers directly, they're still necessary to + // for PyTorch to factor externally allocated memory into its memory pool and garbage collection + // threshold calculation. + _ubuf_torch = torch::from_blob( + _ubuf.dptr(), {static_cast(_ubuf.size(0)), static_cast(_ubuf.size(1))}, + at::device(torch::kCUDA).dtype(buffer_dtype)); + if (_atomic_gemm) { + _ubuf_counter = torch::from_blob(_counter.dptr(), {static_cast(_num_splits * 2)}, + at::device(torch::kCUDA).dtype(torch::kInt32)); + } +} + +/* +** Split AllGather + AtomicGEMM using P2P communication +** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is +*needed to have AG outputs +** in each rank to be in the contiguous memory space after all ring exchange +*phases. 
+*/ +void CommOverlapP2P::atomic_gemm_overlap_ag( + at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, te::DType A_type, bool transa, + at::Tensor B, at::Tensor B_scale_inverse, int64_t B_fp8_tensor, te::DType B_type, bool transb, + at::Tensor D, at::Tensor D_scale, te::DType D_type, at::Tensor D_amax, at::Tensor bias, + te::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, + size_t workspaceSize, bool accumulate, bool use_split_accumulator, at::Tensor B_copy) { + MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, B, B_scale_inverse, + B_fp8_tensor, B_type, D, D_amax, D_scale, D_type, bias, bias_type, + pre_gelu_out, workspace) + + auto B_copy_ = makeTransformerEngineTensor(B_copy); + cudaStream_t stream_main = static_cast(at::cuda::getCurrentCUDAStream()); + te::CommOverlapP2PBase::atomic_gemm_overlap_ag(A_, transa, B_, transb, D_, bias_, pre_gelu_out_, + workspace_, grad, accumulate, + use_split_accumulator, B_copy_, stream_main); +} // atomic_gemm_overlap_ag + +/* +** Split AllGather + GEMM using P2P communication +** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is +*needed to have AG outputs +** in each rank to be in the contiguous memory space after all ring exchange +*phases. +*/ +void CommOverlapP2P::split_overlap_ag( + at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, te::DType A_type, bool transa, + at::Tensor B, at::Tensor B_scale_inverse, int64_t B_fp8_tensor, te::DType B_type, bool transb, + at::Tensor D, at::Tensor D_scale, te::DType D_type, at::Tensor D_amax, at::Tensor bias, + te::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, + size_t workspaceSize, bool accumulate, bool use_split_accumulator, at::Tensor B_copy) { + MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, B, B_scale_inverse, + B_fp8_tensor, B_type, D, D_amax, D_scale, D_type, bias, bias_type, + pre_gelu_out, workspace) + + auto B_copy_ = makeTransformerEngineTensor(B_copy); + cudaStream_t stream_main = static_cast(at::cuda::getCurrentCUDAStream()); + te::CommOverlapP2PBase::split_overlap_ag(A_, transa, B_, transb, D_, bias_, pre_gelu_out_, + workspace_, grad, accumulate, use_split_accumulator, + B_copy_, stream_main); +} // split_overlap_ag + +/* +** Split ReduceScatter + GEMM using P2P communication +*/ +void CommOverlapP2P::atomic_gemm_overlap_rs( + at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, te::DType A_type, bool transa, + at::Tensor B, at::Tensor B_scale_inverse, int64_t B_fp8_tensor, te::DType B_type, bool transb, + at::Tensor D, at::Tensor D_scale, te::DType D_type, at::Tensor D_amax, at::Tensor bias, + te::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, + size_t workspaceSize, bool accumulate, bool use_split_accumulator, at::Tensor rs_output) { + MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, B, B_scale_inverse, + B_fp8_tensor, B_type, D, D_amax, D_scale, D_type, bias, bias_type, + pre_gelu_out, workspace) + + auto rs_out_ = makeTransformerEngineTensor(rs_output); + cudaStream_t stream_main = static_cast(at::cuda::getCurrentCUDAStream()); + te::CommOverlapP2PBase::atomic_gemm_overlap_rs(A_, transa, B_, transb, D_, bias_, pre_gelu_out_, + workspace_, grad, accumulate, + use_split_accumulator, rs_out_, stream_main); +} + +/* +** Split ReduceScatter + GEMM using P2P communication +*/ +void CommOverlapP2P::split_overlap_rs( + at::Tensor A, at::Tensor A_scale_inverse, int64_t 
A_fp8_tensor, te::DType A_type, bool transa, + at::Tensor B, at::Tensor B_scale_inverse, int64_t B_fp8_tensor, te::DType B_type, bool transb, + at::Tensor D, at::Tensor D_scale, te::DType D_type, at::Tensor D_amax, at::Tensor bias, + te::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, + size_t workspaceSize, bool accumulate, bool use_split_accumulator, at::Tensor rs_output) { + MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, B, B_scale_inverse, + B_fp8_tensor, B_type, D, D_amax, D_scale, D_type, bias, bias_type, + pre_gelu_out, workspace) + + auto rs_out_ = makeTransformerEngineTensor(rs_output); + cudaStream_t stream_main = static_cast(at::cuda::getCurrentCUDAStream()); + te::CommOverlapP2PBase::split_overlap_rs(A_, transa, B_, transb, D_, bias_, pre_gelu_out_, + workspace_, grad, accumulate, use_split_accumulator, + rs_out_, stream_main); +} + +/* +** Copy input to _ubufs[0] +*/ +void CommOverlapP2P::copy_input_to_ubuf(torch::Tensor input, bool chunk) { + at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); + if (chunk) { + // Copy input to the target ubuf chunk by rank offset + if (input.numel() != (int64_t)_ubufs[0].numel() || + input.element_size() != (int64_t)_ubufs[0].element_size()) { + NVTE_ERROR("input and ubuf size do not match!"); + } + NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubufs[_tp_id].dptr(), input.data_ptr(), + input.numel() * input.element_size(), cudaMemcpyDeviceToDevice, + (cudaStream_t)stream_main)); + } else { + if (input.numel() != (int64_t)_ubuf.numel() || + input.element_size() != (int64_t)_ubuf.element_size()) { + NVTE_ERROR("input and ubuf size do not match!"); + } + NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubuf.dptr(), input.data_ptr(), + input.numel() * input.element_size(), cudaMemcpyDeviceToDevice, + (cudaStream_t)stream_main)); + } +} + +torch::Tensor CommOverlapP2P::get_ubuf_output(int comm_type) { + char *ubuf_wt_ptr = reinterpret_cast(_ubuf.dptr()); + te::CommOverlapType _comm_type = static_cast(comm_type); + if (_comm_type != te::CommOverlapType::AG && _comm_type != te::CommOverlapType::RS) + NVTE_ERROR("Invalid comm_type"); + if (_comm_type == te::CommOverlapType::RS) + ubuf_wt_ptr += _ubuf.numel() / _tp_size * _self_chunk_id * _ubuf.element_size(); + int output_c_dim0 = + (_comm_type == te::CommOverlapType::AG) ? 
_ubuf.size(0) : _ubuf.size(0) / _tp_size; + int output_c_dim1 = _ubuf.size(1); + return torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, _ubuf_torch.options()); +} diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cu b/transformer_engine/pytorch/csrc/extensions/gemm.cpp similarity index 100% rename from transformer_engine/pytorch/csrc/extensions/gemm.cu rename to transformer_engine/pytorch/csrc/extensions/gemm.cpp diff --git a/transformer_engine/pytorch/csrc/extensions/misc.cu b/transformer_engine/pytorch/csrc/extensions/misc.cpp similarity index 100% rename from transformer_engine/pytorch/csrc/extensions/misc.cu rename to transformer_engine/pytorch/csrc/extensions/misc.cpp diff --git a/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_adam.cu b/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_adam.cu index 09b53a8976..7d49a0848b 100644 --- a/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_adam.cu +++ b/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_adam.cu @@ -179,7 +179,7 @@ struct AdamFunctorMaster { } }; -template +template struct AdamFunctor { __device__ __forceinline__ void operator()(index_t chunk_size, volatile int *noop_gmem, TensorListMetadata<4> &tl, // NOLINT(*) @@ -199,10 +199,10 @@ struct AdamFunctor { index_t chunk_idx = tl.block_to_chunk[blockIdx.x]; index_t n = tl.sizes[tensor_loc]; - T *g = reinterpret_cast(tl.addresses[0][tensor_loc]); + GRAD_T *g = reinterpret_cast(tl.addresses[0][tensor_loc]); g += chunk_idx * chunk_size; - T *p = reinterpret_cast(tl.addresses[1][tensor_loc]); + PARAM_T *p = reinterpret_cast(tl.addresses[1][tensor_loc]); p += chunk_idx * chunk_size; FULL_T *m = reinterpret_cast(tl.addresses[2][tensor_loc]); @@ -223,10 +223,10 @@ struct AdamFunctor { for (int ii = 0; ii < ILP; ii++) { int i = i_start + threadIdx.x + ii * blockDim.x; if (i < n && i < chunk_size) { - r_g[ii] = g[i]; - r_p[ii] = p[i]; - r_m[ii] = m[i]; - r_v[ii] = v[i]; + r_g[ii] = static_cast(g[i]); + r_p[ii] = static_cast(p[i]); + r_m[ii] = static_cast(m[i]); + r_v[ii] = static_cast(v[i]); } else { r_g[ii] = MATH_T(0); r_p[ii] = MATH_T(0); @@ -259,9 +259,9 @@ struct AdamFunctor { for (int ii = 0; ii < ILP; ii++) { int i = i_start + threadIdx.x + ii * blockDim.x; if (i < n && i < chunk_size) { - p[i] = r_p[ii]; - m[i] = r_m[ii]; - v[i] = r_v[ii]; + p[i] = static_cast(r_p[ii]); + m[i] = static_cast(r_m[ii]); + v[i] = static_cast(r_v[ii]); } } } @@ -491,6 +491,7 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag, } } + const auto g_in_type = tensor_lists[0][0].scalar_type(); const auto p_in_type = tensor_lists[1][0].scalar_type(); auto tl_size = tensor_lists.size(); @@ -503,13 +504,15 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag, // Assume single type across p,g,m1,m2 now DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( p_in_type, 0, "adam", - multi_tensor_apply<4>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists, - AdamFunctor(), beta1, beta2, - bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode, - weight_decay);) + DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( + g_in_type, 1, "adam", + multi_tensor_apply<4>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, + tensor_lists, + AdamFunctor(), beta1, + beta2, bias_correction1, bias_correction2, epsilon, lr, + (adamMode_t)mode, weight_decay);)); } else { // g, p, m, v, p_master - const auto g_in_type = tensor_lists[0][0].scalar_type(); DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( 
p_in_type, 0, "adam", DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( @@ -525,12 +528,13 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag, // Assume single type across p,g,m1,m2 now DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( p_in_type, 0, "adam", - multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, - AdamFunctor(), beta1, beta2, - bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode, - weight_decay);) + DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( + g_in_type, 1, "adam", + multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, + AdamFunctor(), beta1, + beta2, bias_correction1, bias_correction2, epsilon, lr, + (adamMode_t)mode, weight_decay);)); } else { - const auto g_in_type = tensor_lists[0][0].scalar_type(); DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( p_in_type, 0, "adam", DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cu b/transformer_engine/pytorch/csrc/extensions/normalization.cpp similarity index 100% rename from transformer_engine/pytorch/csrc/extensions/normalization.cu rename to transformer_engine/pytorch/csrc/extensions/normalization.cpp diff --git a/transformer_engine/pytorch/csrc/extensions/padding.cu b/transformer_engine/pytorch/csrc/extensions/padding.cpp similarity index 100% rename from transformer_engine/pytorch/csrc/extensions/padding.cu rename to transformer_engine/pytorch/csrc/extensions/padding.cpp diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 7bd5a2d8c8..8856553c54 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -4,12 +4,15 @@ * See LICENSE for license information. 
************************************************************************/ -#include +#include +#include -#include "../comm_gemm_overlap.h" #include "../extensions.h" +#include "common/util/pybind_helper.h" PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + NVTE_DECLARE_COMMON_PYBIND11_HANDLES(m) + // Permutation functions m.def("moe_permute_fwd", moe_permute_fwd); m.def("moe_permute_bwd", moe_permute_bwd); @@ -88,6 +91,12 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("grad_output"), py::arg("gelu_input"), py::arg("scale"), py::arg("amax"), py::arg("scale_inv"), py::arg("otype"), py::arg("scale_offset") = 0, py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0); + m.def("fused_dswiglu_cast_transpose", &fused_dswiglu_cast_transpose, + "Fused SwiGLU backward + FP8 cast + FP8 transpose", + py::call_guard(), py::arg("grad_output"), py::arg("input"), + py::arg("grad_input"), py::arg("grad_input_transpose"), py::arg("scale"), py::arg("amax"), + py::arg("scale_inv"), py::arg("otype"), py::arg("scale_offset") = 0, + py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0); m.def("fused_multi_cast_transpose", &fused_multi_cast_transpose, "Fused Multi-tensor Cast + Transpose", py::call_guard()); m.def("fused_multi_cast_transpose_alloc", &fused_multi_cast_transpose_alloc, @@ -226,90 +235,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { .def_readwrite("scale_inv", &transformer_engine::FP8TensorMeta::scale_inv) .def_readwrite("amax_history", &transformer_engine::FP8TensorMeta::amax_history); - m.def("device_supports_multicast", &ubuf::device_supports_multicast, - py::call_guard()); - - m.def("ubuf_built_with_mpi", &ubuf::ubuf_built_with_mpi, - py::call_guard()); - - py::class_(m, "UbufBootstrapCallbacks") - .def(py::init<>(), py::call_guard()) - .def(py::init(), - py::call_guard()); - - py::enum_(m, "UbufOverlapAlgo") - .value("BULK_OVERLAP_AG", ubuf::UBOverlapAlgo::BULK_OVERLAP_AG) - .value("BULK_OVERLAP_RS", ubuf::UBOverlapAlgo::BULK_OVERLAP_RS) - .value("SPLIT_PIPELINED_RS", ubuf::UBOverlapAlgo::SPLIT_PIPELINED_RS) - .value("SPLIT_PIPELINED_RS_P2P", ubuf::UBOverlapAlgo::SPLIT_PIPELINED_RS_P2P) - .value("SPLIT_PIPELINED_AG_P2P", ubuf::UBOverlapAlgo::SPLIT_PIPELINED_AG_P2P) - .value("ATOMIC_GEMM_RS", ubuf::UBOverlapAlgo::ATOMIC_GEMM_RS) - .value("ATOMIC_GEMM_AG_P2P", ubuf::UBOverlapAlgo::ATOMIC_GEMM_AG_P2P) - .value("ATOMIC_GEMM_RS_P2P", ubuf::UBOverlapAlgo::ATOMIC_GEMM_RS_P2P); - - // Note: Can't release GIL in constructor since it may bootstrap - // communicator with Python functions (e.g. 
PyTorch distributed - // communication) - py::class_(m, "UbufCommOverlap") - .def(py::init(), - py::call_guard()) - .def("bulk_overlap", &ubuf::UbufCommOverlap::bulk_overlap, - py::call_guard()) - .def("split_overlap_rs", &ubuf::UbufCommOverlap::split_overlap_rs, - py::call_guard()) - .def("set_ubuf_scale_inv", &ubuf::UbufCommOverlap::set_ubuf_scale_inv, - py::call_guard()) - .def("atomic_gemm_overlap_rs", &ubuf::UbufCommOverlap::atomic_gemm_overlap_rs, - py::call_guard()) - .def("is_fp8_ubuf", &ubuf::UbufCommOverlap::is_fp8_ubuf, - py::call_guard()) - .def("copy_input_to_ubuf", &ubuf::UbufCommOverlap::copy_input_to_ubuf, - py::call_guard()) - .def("get_ubuf_output", &ubuf::UbufCommOverlap::get_ubuf_output, - py::call_guard()) - .def("is_atomic_gemm", &ubuf::UbufCommOverlap::is_atomic_gemm, - py::call_guard()) - .def("is_p2p_overlap", &ubuf::UbufCommOverlap::is_p2p_overlap, - py::call_guard()); - - // Note: Can't release GIL in constructor since it may bootstrap - // communicator with Python functions (e.g. PyTorch distributed - // communication) - py::class_(m, "UbufP2PCommOverlap") - .def(py::init(), - py::call_guard()) - .def("split_overlap_ag_p2p", &ubuf::UbufP2PCommOverlap::split_overlap_ag, - py::call_guard()) - .def("split_overlap_rs_p2p", &ubuf::UbufP2PCommOverlap::split_overlap_rs, - py::call_guard()) - .def("atomic_gemm_overlap_ag_p2p", &ubuf::UbufP2PCommOverlap::atomic_gemm_overlap_ag, - py::call_guard()) - .def("atomic_gemm_overlap_rs_p2p", &ubuf::UbufP2PCommOverlap::atomic_gemm_overlap_rs, - py::call_guard()) - .def("copy_input_to_ubuf", &ubuf::UbufP2PCommOverlap::copy_input_to_ubuf, - py::call_guard()) - .def("get_ubuf_output", &ubuf::UbufP2PCommOverlap::get_ubuf_output, - py::call_guard()) - .def("is_fp8_ubuf", &ubuf::UbufP2PCommOverlap::is_fp8_ubuf, - py::call_guard()) - .def("is_atomic_gemm", &ubuf::UbufP2PCommOverlap::is_atomic_gemm, - py::call_guard()) - .def("is_p2p_overlap", &ubuf::UbufP2PCommOverlap::is_p2p_overlap, - py::call_guard()) - .def("set_ubuf_scale_inv", &ubuf::UbufP2PCommOverlap::set_ubuf_scale_inv, - py::call_guard()); - - py::enum_(m, "DType", py::module_local()) - .value("kByte", transformer_engine::DType::kByte) - .value("kInt32", transformer_engine::DType::kInt32) - .value("kFloat32", transformer_engine::DType::kFloat32) - .value("kFloat16", transformer_engine::DType::kFloat16) - .value("kBFloat16", transformer_engine::DType::kBFloat16) - .value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3) - .value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2); - py::enum_(m, "FP8FwdTensors") .value("GEMM1_INPUT", transformer_engine::FP8FwdTensors::GEMM1_INPUT) .value("GEMM1_WEIGHT", transformer_engine::FP8FwdTensors::GEMM1_WEIGHT) @@ -329,41 +254,61 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { .value("GRAD_OUTPUT3", transformer_engine::FP8BwdTensors::GRAD_OUTPUT3) .value("GRAD_INPUT3", transformer_engine::FP8BwdTensors::GRAD_INPUT3); - py::enum_(m, "NVTE_Bias_Type") - .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) - .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS) - .value("NVTE_POST_SCALE_BIAS", NVTE_Bias_Type::NVTE_POST_SCALE_BIAS) - .value("NVTE_ALIBI", NVTE_Bias_Type::NVTE_ALIBI); - - py::enum_(m, "NVTE_Mask_Type") - .value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK) - .value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK) - .value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK) - .value("NVTE_PADDING_CAUSAL_MASK", NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) - .value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", 
NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) - .value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK", - NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); + py::class_(m, "CommOverlapHelper") + .def(py::init<>(), py::call_guard()) + .def(py::init, + std::optional>(), + py::call_guard(), py::arg("world_group"), + py::arg("intra_node_group") = py::none(), py::arg("inter_node_group") = py::none()); - py::enum_(m, "NVTE_QKV_Layout") - .value("NVTE_SB3HD", NVTE_QKV_Layout::NVTE_SB3HD) - .value("NVTE_SBH3D", NVTE_QKV_Layout::NVTE_SBH3D) - .value("NVTE_SBHD_SB2HD", NVTE_QKV_Layout::NVTE_SBHD_SB2HD) - .value("NVTE_SBHD_SBH2D", NVTE_QKV_Layout::NVTE_SBHD_SBH2D) - .value("NVTE_SBHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD) - .value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD) - .value("NVTE_BSH3D", NVTE_QKV_Layout::NVTE_BSH3D) - .value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD) - .value("NVTE_BSHD_BSH2D", NVTE_QKV_Layout::NVTE_BSHD_BSH2D) - .value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) - .value("NVTE_T3HD", NVTE_QKV_Layout::NVTE_T3HD) - .value("NVTE_TH3D", NVTE_QKV_Layout::NVTE_TH3D) - .value("NVTE_THD_T2HD", NVTE_QKV_Layout::NVTE_THD_T2HD) - .value("NVTE_THD_TH2D", NVTE_QKV_Layout::NVTE_THD_TH2D) - .value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD); + py::class_(m, "CommOverlap") + .def(py::init &, at::ScalarType, CommOverlapHelper *, int, int, int, + int, int, bool, bool>(), + py::call_guard(), py::arg("buffer_shape"), + py::arg("buffer_dtype"), py::arg("helper"), py::arg("tp_size"), + py::arg("num_splits") = 3, py::arg("num_max_streams") = NVTE_COMM_OVERLAP_MAX_STREAMS, + py::arg("comm_cga_size") = 2, py::arg("num_comm_sm") = 16, + py::arg("set_sm_margin") = true, py::arg("atomic_gemm") = false) + .def("bulk_overlap", &CommOverlap::bulk_overlap, py::call_guard()) + .def("split_overlap_rs", &CommOverlap::split_overlap_rs, + py::call_guard()) + .def("atomic_gemm_overlap_rs", &CommOverlap::atomic_gemm_overlap_rs, + py::call_guard()) + .def("copy_input_to_ubuf", &CommOverlap::copy_input_to_ubuf, + py::call_guard()) + .def("get_ubuf_output", &CommOverlap::get_ubuf_output, + py::call_guard()) + .def("set_ubuf_scale_inv", &CommOverlap::set_ubuf_scale_inv, + py::call_guard()) + .def("is_atomic_gemm", &CommOverlap::is_atomic_gemm, py::call_guard()) + .def("is_p2p_overlap", &CommOverlap::is_p2p_overlap, py::call_guard()) + .def("is_fp8_ubuf", &CommOverlap::is_fp8_ubuf, py::call_guard()); - py::enum_(m, "NVTE_Fused_Attn_Backend") - .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) - .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) - .value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8) - .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend); + py::class_(m, "CommOverlapP2P") + .def(py::init &, at::ScalarType, CommOverlapHelper *, int, + transformer_engine::CommOverlapType, int, int, int, bool, bool, bool, bool>(), + py::call_guard(), py::arg("buffer_shape"), + py::arg("buffer_dtype"), py::arg("helper"), py::arg("tp_size"), py::arg("comm_type"), + py::arg("num_max_streams") = NVTE_COMM_OVERLAP_MAX_STREAMS, py::arg("comm_cga_size") = 1, + py::arg("num_comm_sm") = 1, py::arg("set_sm_margin") = false, + py::arg("atomic_gemm") = false, py::arg("use_ce") = true, py::arg("aggregate") = false) + .def("split_overlap_ag_p2p", &CommOverlapP2P::split_overlap_ag, + py::call_guard()) + .def("split_overlap_rs_p2p", &CommOverlapP2P::split_overlap_rs, + py::call_guard()) + 
.def("atomic_gemm_overlap_ag_p2p", &CommOverlapP2P::atomic_gemm_overlap_ag, + py::call_guard()) + .def("atomic_gemm_overlap_rs_p2p", &CommOverlapP2P::atomic_gemm_overlap_rs, + py::call_guard()) + .def("copy_input_to_ubuf", &CommOverlapP2P::copy_input_to_ubuf, + py::call_guard()) + .def("get_ubuf_output", &CommOverlapP2P::get_ubuf_output, + py::call_guard()) + .def("set_ubuf_scale_inv", &CommOverlapP2P::set_ubuf_scale_inv, + py::call_guard()) + .def("is_fp8_ubuf", &CommOverlapP2P::is_fp8_ubuf, py::call_guard()) + .def("is_atomic_gemm", &CommOverlapP2P::is_atomic_gemm, + py::call_guard()) + .def("is_p2p_overlap", &CommOverlapP2P::is_p2p_overlap, + py::call_guard()); } diff --git a/transformer_engine/pytorch/csrc/extensions/recipe.cu b/transformer_engine/pytorch/csrc/extensions/recipe.cpp similarity index 100% rename from transformer_engine/pytorch/csrc/extensions/recipe.cu rename to transformer_engine/pytorch/csrc/extensions/recipe.cpp diff --git a/transformer_engine/pytorch/csrc/extensions/softmax.cu b/transformer_engine/pytorch/csrc/extensions/softmax.cpp similarity index 100% rename from transformer_engine/pytorch/csrc/extensions/softmax.cu rename to transformer_engine/pytorch/csrc/extensions/softmax.cpp diff --git a/transformer_engine/pytorch/csrc/extensions/transpose.cu b/transformer_engine/pytorch/csrc/extensions/transpose.cpp similarity index 82% rename from transformer_engine/pytorch/csrc/extensions/transpose.cu rename to transformer_engine/pytorch/csrc/extensions/transpose.cpp index 56f6b56769..f373cdf83a 100644 --- a/transformer_engine/pytorch/csrc/extensions/transpose.cu +++ b/transformer_engine/pytorch/csrc/extensions/transpose.cpp @@ -196,6 +196,75 @@ std::vector fused_cast_transpose_bgrad_dgelu(at::Tensor grad_output, return {grad_bias, dgelu, dgelu_transpose}; } +void fused_dswiglu_cast_transpose(at::Tensor grad_output, at::Tensor input, at::Tensor grad_input, + at::Tensor grad_input_transpose, at::Tensor scale, + at::Tensor amax, at::Tensor scale_inv, + transformer_engine::DType otype, int scale_offset, + int amax_offset, int scale_inv_offset) { + using namespace transformer_engine; + + // Tensor dimensions + auto outer_dim = [](const at::Tensor& tensor) -> size_t { + return tensor.numel() / tensor.size(-1); + }; + const auto M = outer_dim(grad_output); + const auto N = static_cast(grad_output.size(-1)); + + // Check tensor dims + NVTE_CHECK(grad_output.dim() == 2, "Expected grad output tensor to have 2 dims, but found ", + grad_output.dim()); + NVTE_CHECK(input.dim() == 2, "Expected input tensor to have 2 dims, but found ", input.dim()); + NVTE_CHECK(outer_dim(input) == M, "Expected input tensor to have outer dimension of ", M, + ", but found ", outer_dim(input)); + NVTE_CHECK(input.size(-1) == 2 * N, "Expected input tensor to have inner dimension of ", 2 * N, + ", but found ", input.size(-1)); + NVTE_CHECK(grad_input.dim() == 2, "Expected grad input tensor to have 2 dims, but found ", + grad_input.dim()); + NVTE_CHECK(outer_dim(grad_input) == M, "Expected grad input tensor to have outer dimension of ", + M, ", but found ", outer_dim(grad_input)); + NVTE_CHECK(grad_input.size(-1) == 2 * N, "Expected grad input tensor to have inner dimension of ", + 2 * N, ", but found ", grad_input.size(-1)); + NVTE_CHECK(grad_input_transpose.dim() == 2, + "Expected grad input transpose tensor to have 2 dims, but found ", + grad_input_transpose.dim()); + NVTE_CHECK(grad_input_transpose.size(0) == 2 * N, + "Expected grad input tensor to have outer dimension of ", 2 * N, ", but found ", + 
grad_input_transpose.size(0)); + NVTE_CHECK(grad_input_transpose.size(1) == M, + "Expected grad input tensor to have outer dimension of ", M, ", but found ", + grad_input_transpose.size(1)); + + // Check tensor format + NVTE_CHECK(grad_output.is_contiguous(), "Expected grad output tensor to be contiguous"); + NVTE_CHECK(input.is_contiguous(), "Expected input tensor to be contiguous"); + NVTE_CHECK(grad_input.is_contiguous(), "Expected grad input tensor to be contiguous"); + NVTE_CHECK(grad_input_transpose.is_contiguous(), + "Expected grad input transpose tensor to be contiguous"); + NVTE_CHECK(grad_output.scalar_type() == input.scalar_type(), + "Expected grad output tensor and input tensor to have same dtype"); + NVTE_CHECK(grad_input.scalar_type() == at::ScalarType::Byte, + "Expected grad input tensor to be uint8 buffer"); + NVTE_CHECK(grad_input_transpose.scalar_type() == at::ScalarType::Byte, + "Expected grad input transpose tensor to be uint8 buffer"); + + // Get pointers for FP8 scale, amax, scale-inverse + void* scale_dptr = getDataPtr(scale, scale_offset); + void* amax_dptr = getDataPtr(amax, amax_offset); + void* scale_inv_dptr = getDataPtr(scale_inv, scale_inv_offset); + + // Construct Transformer Engine tensors + auto dy_cu = makeTransformerEngineTensor(grad_output); + auto x_cu = makeTransformerEngineTensor(input); + auto dx_cu = makeTransformerEngineTensor(grad_input.data_ptr(), {M, 2 * N}, otype, amax_dptr, + scale_dptr, scale_inv_dptr); + auto dx_t_cu = makeTransformerEngineTensor(grad_input_transpose.data_ptr(), {2 * N, M}, otype, + amax_dptr, scale_dptr, scale_inv_dptr); + + // Launch kernel + nvte_dswiglu_cast_transpose(dy_cu.data(), x_cu.data(), dx_cu.data(), dx_t_cu.data(), + at::cuda::getCurrentCUDAStream()); +} + void fused_multi_cast_transpose_base(std::vector input_list, std::vector scale_dptr_list, std::vector cast_output_list, diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 76679eb064..2a909dabc6 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -109,8 +109,6 @@ def reset(cls) -> None: cls.fp8_available = None cls.reason_for_no_fp8 = "" cls.autocast_arguments = {} - cls.autocast_to_fp8_params = {} - cls.fp8_param_to_autocast = {} cls.skip_fp8_weight_update_tensor = None @classmethod @@ -156,28 +154,25 @@ def get_buffer_info(cls) -> str: def get_key_in_buffer( cls, forward: bool, - fp8_weights: bool, fp8_recipe: DelayedScaling, fp8_group: dist_group_type, ) -> str: """Returns a key into the global FP8 buffers.""" autocast_key = cls.get_unique_autocast_key(fp8_recipe, fp8_group) fwd_bwd_key = cls.get_fwd_bwd_key(forward) - return f"{fwd_bwd_key}_{fp8_weights}_{autocast_key}" + return f"{fwd_bwd_key}_{autocast_key}" @classmethod - def split_key_in_buffer(cls, key: str) -> Tuple[bool, bool, str]: + def split_key_in_buffer(cls, key: str) -> Tuple[bool, str]: """Splits buffer key into relevant parts.""" - forward, fp8_weights, autocast_key = key.split("_", 2) + forward, autocast_key = key.split("_", 1) forward = forward == "forward" - fp8_weights = fp8_weights == "True" - return forward, fp8_weights, autocast_key + return forward, autocast_key @classmethod def add_fp8_tensors_to_global_buffer( cls, fp8_meta: Dict[str, Any], - fp8_weights: Optional[List[torch.Tensor]] = None, ) -> None: """ The amax reduction process happens completely outside the FP8 modules. 
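With the `fp8_weights` flag removed, the global amax-buffer key reduces to a forward/backward tag plus the autocast key. A minimal sketch of the new key layout (the autocast key below is a made-up placeholder, not a value from this patch):

    fwd_bwd_key, autocast_key = "forward", "recipe0-group0"  # placeholder values
    key = f"{fwd_bwd_key}_{autocast_key}"  # previously f"{fwd_bwd_key}_{fp8_weights}_{autocast_key}"
    forward, autocast = key.split("_", 1)  # mirrors the new split_key_in_buffer()
    assert forward == "forward" and autocast == "recipe0-group0"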
@@ -202,33 +197,12 @@ def add_fp8_tensors_to_global_buffer( fp8_meta[index_in_buffer] = [] for forward in (True, False): - # This algorithm creates a two-way map with `autocast_to_fp8_params` and - # `fp8_param_to_autocast`. This is used for keeping track of FP8 weights - # in an autocasted region and cross reference them in `float8_tensor.py` - # to perform the forward amax reduction. fp8_meta_tensor_key = cls.get_meta_tensor_key(forward=forward) if fp8_meta_tensor_key not in fp8_meta: # Handles non-parameter FP8 modules, e.g. DPA. continue - if forward and fp8_weights is not None: - autocast_key = cls.get_unique_autocast_key( - fp8_meta["recipe"], fp8_meta["fp8_group"] - ) - fp8_weight_set = {id(w._data) for w in fp8_weights} - if autocast_key not in cls.autocast_to_fp8_params: - cls.autocast_to_fp8_params[autocast_key] = fp8_weight_set - else: - cls.autocast_to_fp8_params[autocast_key] = cls.autocast_to_fp8_params[ - autocast_key - ].union(fp8_weight_set) - # Identify correct autocast key for a given param. - for w in fp8_weight_set: - cls.fp8_param_to_autocast[w] = autocast_key - - key = cls.get_key_in_buffer( - forward, fp8_weights is not None, fp8_meta["recipe"], fp8_meta["fp8_group"] - ) + key = cls.get_key_in_buffer(forward, fp8_meta["recipe"], fp8_meta["fp8_group"]) if key not in cls.global_amax_buffer: cls.global_amax_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history[0]] @@ -277,7 +251,9 @@ def is_first_fp8_module(cls): @classmethod def get_fp8_recipe(cls) -> DelayedScaling: """Return the fp8 recipe""" - return cls.FP8_RECIPE + if cls.FP8_RECIPE is not None: + return cls.FP8_RECIPE + return get_default_fp8_recipe() @classmethod def get_fp8_group(cls) -> Union[dist_group_type, None]: @@ -325,20 +301,13 @@ def reduce_tensor_across_group_op_max(tensor: torch.Tensor, group: dist_group_ty def reduce_and_update_fp8_tensors( cls, forward: bool = True, - fp8_weights: bool = False, ) -> None: """Concatenate, reduce, and split amaxes in the global buffer.""" for buffer_key, amax_buffer in cls.global_amax_buffer.items(): # Check for forward or backward reduction. - fwd_update, fp8_weights_update, autocast_key = cls.split_key_in_buffer(buffer_key) + fwd_update, autocast_key = cls.split_key_in_buffer(buffer_key) if fwd_update != forward: continue - # Only skip a forward update when `fp8_weights` is explicitly set to `True` - # (inside optimizer) and the current key is not an `fp8_weight_update` key. - # For other cases, we need to reduce because of activation tensors. - # TODO(ksivaman) consider separate weight and activation fp8_tensors. - if fwd_update and fp8_weights and not fp8_weights_update: - continue if len(amax_buffer) == 0: continue @@ -432,7 +401,7 @@ def fp8_autocast_exit(cls, enabled: bool, _graph: bool) -> None: # FP8 weight modules are reduced at the end of the optimizer # step after the weight amax is populated. if enabled and cls.FP8_AUTOCAST_DEPTH == 0 and not _graph and torch.is_grad_enabled(): - cls.reduce_and_update_fp8_tensors(forward=True, fp8_weights=False) + cls.reduce_and_update_fp8_tensors(forward=True) @classmethod def copy_forward_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> None: diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index cba71e1326..c47b792a95 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -3,6 +3,7 @@ # See LICENSE for license information. 
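One user-visible effect of the fp8.py change above: `FP8GlobalStateManager.get_fp8_recipe()` now falls back to the default recipe instead of returning `None` outside an autocast region. A small sketch of that behavior (assuming no `fp8_autocast()` is active):

    from transformer_engine.pytorch.fp8 import FP8GlobalStateManager, get_default_fp8_recipe

    recipe = FP8GlobalStateManager.get_fp8_recipe()  # called outside any fp8_autocast() region
    assert recipe is not None                        # previously this could be None
    assert type(recipe) is type(get_default_fp8_recipe())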
"""Functions for CUDA Graphs support in FP8""" +from collections.abc import Iterable from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union import torch @@ -18,7 +19,7 @@ ) from .distributed import get_all_rng_states, graph_safe_rng_available from .module.base import TransformerEngineBaseModule - +from .ops.op import BasicOperation __all__ = ["make_graphed_callables"] @@ -464,7 +465,7 @@ def new_fwd(*user_args, **user_kwargs): m.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group() m.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe() FP8GlobalStateManager.add_fp8_tensors_to_global_buffer( - m.fp8_meta, fp8_weights=m._get_fp8_params() + m.fp8_meta, ) return graphed(*user_args, **user_kwargs) return orig_fwd(*user_args, **user_kwargs) @@ -486,28 +487,46 @@ def new_fwd(*user_args, **user_kwargs): return tuple(ret) -def save_fp8_tensors(modules, amax_history_len): +def save_fp8_tensors( + modules: Iterable[torch.nn.Module], + fp8_recipe: DelayedScaling, +) -> List[Any]: """ Returns the FP8 tensors for all modules with adjusted amax history sizes. """ - saved_fp8_meta_tensors = [] + fp8_tensors = [] for module in modules: for m in module.modules(): + module_tensors = None if isinstance(m, TransformerEngineBaseModule): if m.primary_weights_in_fp8: - m.adjust_amax_history_length(amax_history_len) - saved_fp8_meta_tensors.append(m.get_fp8_meta_tensors()) - return saved_fp8_meta_tensors - - -def restore_fp8_tensors(modules, fp8_tensors): + m.adjust_amax_history_length(fp8_recipe.amax_history_len) + module_tensors = m.get_fp8_meta_tensors() + elif isinstance(m, BasicOperation): + m.pre_forward(fp8_enabled=True, fp8_recipe=fp8_recipe) + module_tensors = m._save_fp8_metas() + fp8_tensors.append(module_tensors) + return fp8_tensors + + +def restore_fp8_tensors( + modules: Iterable[torch.nn.Module], + fp8_tensors: List[Any], +) -> None: """Restore FP8 tensors.""" for module in modules: for m in module.modules(): + module_tensors = fp8_tensors.pop(0) if isinstance(m, TransformerEngineBaseModule): - m.reset_fp8_meta_tensors(fp8_tensors.pop(0)) - assert len(fp8_tensors) == 0, "TE internal error." + m.reset_fp8_meta_tensors(module_tensors) + elif isinstance(m, BasicOperation): + m._load_fp8_metas(module_tensors) + if len(fp8_tensors) != 0: + raise RuntimeError( + f"Got FP8 state for {len(fp8_tensors)} more modules than expected. " + "There is probably a discrepancy with `save_fp8_tensors`." + ) def make_graphed_callables( @@ -580,7 +599,7 @@ def make_graphed_callables( modules = (modules,) # Store FP8 tensors to reset later. - saved_fp8_tensors = save_fp8_tensors(modules, fp8_recipe.amax_history_len) + saved_fp8_tensors = save_fp8_tensors(modules, fp8_recipe=fp8_recipe) # FP8 wrapper. def wrap_autocast(block): diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index bc4a06b4cb..3a15242c3a 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -87,9 +87,55 @@ def initialize_ub( ub_cfgs: Optional[dict] = None, bootstrap_backend: Union[str, torch.distributed.Backend] = None, ) -> None: - """Initialize communicators for TP comm overlap using userbuffers.""" + r""" + Initialize the Userbuffers communicator for overlapping tensor-parallel communications with + GEMM compute in te.Linear, te.LayerNormLinear and te.LayerNormMLP modules. 
+
+    Parameters
+    ----------
+    shape : list
+            shape of the communication buffer, typically set to be the same as the global shape of
+            the input tensor to a te.TransformerLayer forward pass, with the sequence and batch
+            dimensions collapsed together -- i.e.: `(sequence_length * batch_size, hidden_size)`
+    tp_size : int
+              number of GPUs in the tensor-parallel process group
+    use_fp8 : bool = False
+              allocate the communication buffer for FP8 GEMM inputs/outputs
+    dtype : torch.dtype = torch.bfloat16
+            non-FP8 data type of the communication buffer when `use_fp8 = False`
+    ub_cfgs : dict = None
+              Configuration dictionary with the structure
+              ```
+              {
+                <gemm_name> : {
+                  "method": <"ring_exchange" or "pipeline">,
+                  "is_reduce_scatter": bool,
+                  "num_sm": int,
+                  "cga_size": int,
+                  "set_sm_margin": bool,
+                  "num_splits": int,
+                  "aggregate": bool,
+                  "atomic_gemm": bool,
+                  "use_ce": bool,
+                  "fp8_buf": bool,
+                }
+              }
+              ```
+              for `te.TransformerLayer` GEMM layers in `["qkv_fprop", "qkv_dgrad", "qkv_wgrad",
+              "proj_fprop", "proj_dgrad", "proj_wgrad", "fc1_fprop", "fc1_dgrad",
+              "fc2_fprop", "fc2_dgrad"]`.
+    bootstrap_backend : str = None
+                        `torch.distributed` communication backend for the all-gather, broadcast and
+                        barrier collectives during Userbuffers initialization. Not all backends are
+                        valid for every cluster configuration and distributed launch method even if
+                        they are available in PyTorch. When left unset, the initialization prefers
+                        to use the MPI backend, falling back first on Gloo and then NCCL if MPI is
+                        not available. Setting `NVTE_UB_WITH_MPI=1` when building TE overrides this
+                        option and always initializes Userbuffers with direct MPI calls in C++,
+                        which also requires `MPI_HOME=/path/to/mpi/root` to be set at compile time.
+    """
     if not tex.device_supports_multicast():
-        assert bool(os.getenv("UB_SKIPMC", "0")), (
+        assert bool(int(os.getenv("UB_SKIPMC", "0"))), (
             "CUDA device, driver and/or toolkit version does not support comm+GEMM overlap with "
             + "CUDA Multicast. Launch app with UB_SKIPMC=1 to try CUDA IPC instead."
         )
@@ -99,50 +145,52 @@ def initialize_ub(
     _ub_communicators = {}
     if tex.ubuf_built_with_mpi():
-        # Userbuffers will ignore all these values when it is built with MPI, so these are just
-        # placeholders based on an assumption that tp_size covers all devices in a physical node.
+        # We're bootstrapping with direct calls to MPI in Userbuffers code, so we need to force
+        # an MPI_Init() here by creating a new MPI process group...
        assert torch.distributed.is_mpi_available()
-        mpi_group = torch.distributed.new_group(backend="mpi")
-        world_rank = torch.distributed.get_rank(mpi_group)
-        world_size = torch.distributed.get_world_size(mpi_group)
-        local_rank = world_rank % tp_size
-        local_size = tp_size
-        self_node_idx = world_rank // tp_size
-        num_nodes = world_size // tp_size
-        ub_callbacks = tex.UbufBootstrapCallbacks()
+        _ = torch.distributed.new_group(backend="mpi")
+        helper = tex.CommOverlapHelper()
     else:
+        # Bootstrapping with torch.distributed API, so check backend and construct
+        # intra/inter-node process groups...
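As an aside on the `UB_SKIPMC` assertion earlier in this function: the added `int()` cast matters because any non-empty string is truthy in Python. A quick illustration (not part of the patch):

    import os

    os.environ["UB_SKIPMC"] = "0"
    assert bool(os.environ["UB_SKIPMC"]) is True        # old check: "0" is a non-empty string
    assert bool(int(os.environ["UB_SKIPMC"])) is False  # new check: respects the numeric value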
assert ( torch.distributed.is_initialized() ), "torch.distributed must be initialized before Userbuffers" if bootstrap_backend is None: bootstrap_backend = "nccl" - if torch.distributed.is_gloo_available(): - bootstrap_backend = "gloo" - elif torch.distributed.is_mpi_available(): + if torch.distributed.is_mpi_available(): bootstrap_backend = "mpi" + elif torch.distributed.is_gloo_available(): + bootstrap_backend = "gloo" else: - assert bootstrap_backend in ["gloo", "mpi", "nccl"] + assert bootstrap_backend in [ + "gloo", + "mpi", + "nccl", + ], "Invalid torch.distributed backend for bootstrapping Userbuffers!" + assert torch.distributed.is_backend_available(bootstrap_backend), ( + f"PyTorch must be compiled with '{bootstrap_backend}' support in order to " + f"bootstrap Userbuffers with '{bootstrap_backend}' collectives." + ) world_group = torch.distributed.new_group(backend=bootstrap_backend) world_rank = torch.distributed.get_rank(world_group) world_size = torch.distributed.get_world_size(world_group) - # Construct an intra-node communicator based on global ranks that share the same hostname - # NOTE: If the user specified a valid network interface for NCCL or GLOO, use the host - # address on that interface instead of the hostname. This can help avoid issues when - # different hosts have the same hostname on Kubernetes clusters. - hostname = socket.gethostname() + # We have single-node NVLink so we can color based on physical node hostnames. + # NOTE: Prefer a network interface defined via the NVTE_UB_SOCKET_IFNAME variable, and + # otherwise fall back on NCCL_SOCKET_IFNAME or GLOO_SOCKET_IFNAME depending on + # the chosen bootstrap backend. + mydomain = socket.gethostname() ifname = os.getenv( - "NVTE_UB_SOCKET_IFNAME", - os.getenv("NCCL_SOCKET_IFNAME", os.getenv("GLOO_SOCKET_IFNAME")), + "NVTE_UB_SOCKET_IFNAME", os.getenv(f"{bootstrap_backend.upper()}_SOCKET_IFNAME") ) - if ifname is not None: # Make sure the ifname found in the environment is a valid network interface if ifname in [name for _, name in socket.if_nameindex()]: s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) try: - hostname = socket.inet_ntoa( + mydomain = socket.inet_ntoa( fcntl.ioctl( s.fileno(), 0x8915, struct.pack("256s", ifname[:15].encode("UTF-8")) )[20:24] @@ -154,57 +202,64 @@ def initialize_ub( else: ifname_warning = ( f"'{ifname}' is not a valid network interface! `te.initialize_ub()` will" - " attempt to " - + "detect ranks on the same node by matching 'socket.gethostname()', which is " - + "known to fail on virtual clusters like Kubernetes. If Userbuffers " - + "initialization fails, please set the 'NVTE_UB_SOCKET_IFNAME' variable in " - + "your environment to the correct network interface." + + " attempt to detect ranks on the same node by matching " + + "'socket.gethostname()', which is known to fail on virtual clusters like " + + "Kubernetes. If Userbuffers initialization fails, please set the " + + "'NVTE_UB_SOCKET_IFNAME' variable in your environment to the correct network " + + "interface." 
) warnings.warn(ifname_warning, UserWarning) - hostnames = [None for _ in range(world_size)] - torch.distributed.all_gather_object(hostnames, hostname, world_group) - unique_hosts = [] - for host in hostnames: - if host not in unique_hosts: - unique_hosts.append(host) - num_nodes = len(unique_hosts) - - if num_nodes > 1: - ranks_per_node_list = [[] for _ in range(num_nodes)] - self_node_idx = -1 - for i, host in enumerate(hostnames): - node_idx = unique_hosts.index(host) - ranks_per_node_list[node_idx].append(i) - if host == hostname: - self_node_idx = node_idx - assert self_node_idx >= 0, "Internal TE error!" - - intra_node_group, _ = torch.distributed.new_subgroups_by_enumeration( - ranks_per_node_list, backend=bootstrap_backend + # Allgather the domain colors across ranks and reduce to a list of unique domains + domain_per_rank_list = [None for _ in range(world_size)] + torch.distributed.all_gather_object(domain_per_rank_list, mydomain, world_group) + unique_domains = [] + for domain in domain_per_rank_list: + if domain not in unique_domains: + unique_domains.append(domain) + num_domains = len(unique_domains) + + if num_domains > 1: + # DP/TP model replicated on multiple NVLink domains + ranks_per_domain_list = [[] for _ in range(num_domains)] + mydomain_idx = -1 + for i, domain in enumerate(domain_per_rank_list): + domain_idx = unique_domains.index(domain) + ranks_per_domain_list[domain_idx].append(i) + if domain == mydomain: + mydomain_idx = domain_idx + assert mydomain_idx >= 0, "Internal TE error!" + + intra_domain_group, _ = torch.distributed.new_subgroups_by_enumeration( + ranks_per_domain_list, backend=bootstrap_backend + ) + local_rank = torch.distributed.get_rank(intra_domain_group) + intra_domain_ranks = torch.distributed.get_process_group_ranks(intra_domain_group) + + inter_domain_group, _ = torch.distributed.new_subgroups_by_enumeration( + [list(ranks) for ranks in zip(*ranks_per_domain_list)], + backend=bootstrap_backend, ) - local_rank = torch.distributed.get_rank(intra_node_group) - local_size = torch.distributed.get_world_size(intra_node_group) - intra_node_ranks = torch.distributed.get_process_group_ranks(intra_node_group) + + helper = tex.CommOverlapHelper(world_group, intra_domain_group, inter_domain_group) else: - self_node_idx = 0 - intra_node_group = world_group + # TP model on single NVLink domain, no replication, no data-parallelism + mydomain_idx = 0 local_rank = world_rank - local_size = world_size - intra_node_ranks = list(range(world_size)) + intra_domain_ranks = list(range(world_size)) + + helper = tex.CommOverlapHelper(world_group) if world_rank == 0: - print(f"!!! [UB] Number of physical nodes: {num_nodes}\n", end="", flush=True) + print(f"!!! [UB] Number of NVLink domains: {num_domains}\n", end="", flush=True) if local_rank == 0: print( - f"!!! [UB] Global ranks on node {self_node_idx}: {intra_node_ranks}\n", + f"!!! 
[UB] Global ranks on domain {mydomain_idx}: {intra_domain_ranks}\n", end="", flush=True, ) - ub_callbacks = tex.UbufBootstrapCallbacks(world_group, intra_node_group) - # Increase the workspace by the number of maximum concurrent streams global _cublas_workspace _cublas_workspace = get_workspace().repeat(_NUM_MAX_UB_STREAMS) @@ -303,46 +358,34 @@ def add_ub( if atomic_gemm and method == "ring_exchange": assert rs_ag_pairs[name] in layers_atomic_ring_exchange, assert_message - sample_buffer = torch.empty( - shape, dtype=torch.uint8 if (use_fp8 and fp8_buf) else dtype, device="cuda" - ) + buffer_dtype = torch.uint8 if (use_fp8 and fp8_buf) else dtype if method == "ring_exchange": - ub_obj = tex.UbufP2PCommOverlap( - sample_buffer, # Sample userbuffer - world_rank, # World rank - world_size, # World size - local_rank, # Rank within the node - local_size, # Number of ranks/GPUs per node - self_node_idx, # Node ID - num_nodes, # Number of nodes + ub_obj = tex.CommOverlapP2P( + shape, # Communication buffer shape + buffer_dtype, # Communication buffer data type + helper, # Helper for torch.distributed callbacks during bootstrapping tp_size, # Tensor-parallel group size (may be different than local_size) - num_sm, # Number of communication SMs - cga_size, # CGA cluster size - set_sm_margin, # Set SM margin - aggregate, # Aggregate 2X GEMM chunks - _NUM_MAX_UB_STREAMS, # Max concurrent GEMM streams - is_reduce_scatter, # Overlap with reduce scatter - atomic_gemm, # Use a single GEMM with atomic-counters - use_ce, # Use copy engine for P2P communications - ub_callbacks, + tex.CommOverlapType.RS if is_reduce_scatter else tex.CommOverlapType.AG, + num_max_streams=_NUM_MAX_UB_STREAMS, + comm_cga_size=cga_size, + num_comm_sm=num_sm, + set_sm_margin=set_sm_margin, + atomic_gemm=atomic_gemm, + use_ce=use_ce, + aggregate=aggregate, ) else: - ub_obj = tex.UbufCommOverlap( - sample_buffer, # Sample userbuffer - world_rank, # World rank - world_size, # World size - local_rank, # Rank within the node - local_size, # Number of ranks/GPUs per node - self_node_idx, # Node ID - num_nodes, # Number of nodes + ub_obj = tex.CommOverlap( + shape, # Communication buffer shape + buffer_dtype, # Communication buffer data type + helper, # Helper for torch.distributed callbacks during bootstrapping tp_size, # Tensor-parallel group size (may be different than local_size) - num_sm, # Number of communication SMs - cga_size, # CGA cluster size - num_splits, # Number of communication splits - set_sm_margin, # Set SM margin - _NUM_MAX_UB_STREAMS, # Max concurrent GEMM streams - atomic_gemm, # Use a single GEMM with atomic-counters - ub_callbacks, + num_splits=num_splits, + num_max_streams=_NUM_MAX_UB_STREAMS, + comm_cga_size=cga_size, + num_comm_sm=num_sm, + set_sm_margin=set_sm_margin, + atomic_gemm=atomic_gemm, ) _ub_communicators[name] = ub_obj @@ -719,9 +762,7 @@ def prepare_forward( ) if self.fp8 and not FP8GlobalStateManager.fp8_graph_capturing(): - FP8GlobalStateManager.add_fp8_tensors_to_global_buffer( - self.fp8_meta, fp8_weights=self._get_fp8_params() - ) + FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(self.fp8_meta) # Activation recomputation is used and this is the first forward phase. 
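# Editor's sketch (not part of the patch): bootstrapping Userbuffers with the new
# `bootstrap_backend` path shown above. Launch under torchrun, e.g.
#   NVTE_UB_SOCKET_IFNAME=eth0 torchrun --nproc-per-node=8 ub_example.py
# The interface name, buffer shape, and the keyword names other than
# `bootstrap_backend` (shape, tp_size, use_fp8, dtype) are assumptions or example
# values, so verify them against the released te.initialize_ub signature.
import os

import torch
import transformer_engine.pytorch as te

# torch.distributed must be initialized first, as asserted above. When
# bootstrap_backend is None the code now prefers MPI, then Gloo, then NCCL; an
# explicit choice must be available in the PyTorch build
# (torch.distributed.is_backend_available).
torch.distributed.init_process_group(backend="nccl")
torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", "0")))

te.initialize_ub(
    [4 * 1024, 8 * 1024],                 # assumed: (tokens, hidden) sample buffer shape
    torch.distributed.get_world_size(),   # assumed: tensor-parallel group size (single node)
    use_fp8=False,
    dtype=torch.bfloat16,
    bootstrap_backend="gloo",             # argument exercised by this patch
)
# ... run TE modules with their ub_overlap_* options enabled ...
te.destroy_ub()  # assumed cleanup counterpart exported alongside initialize_ub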
if self.fp8 and self.training and is_fp8_activation_recompute_enabled(): diff --git a/transformer_engine/pytorch/module/layernorm.py b/transformer_engine/pytorch/module/layernorm.py index 0c439ac417..b42079d299 100644 --- a/transformer_engine/pytorch/module/layernorm.py +++ b/transformer_engine/pytorch/module/layernorm.py @@ -3,158 +3,107 @@ # See LICENSE for license information. """LayerNorm API""" -import os import warnings -from typing import Union, Tuple, Optional +from typing import Iterable, Optional, Union import torch -from torch.nn.parameter import Parameter -from torch.nn import init -import transformer_engine_torch as tex -from ..cpp_extensions import ( - layernorm_fwd_inf, -) -from ..jit import no_torch_dynamo -from ..utils import cast_if_needed +from transformer_engine.pytorch.ops import LayerNorm as _LayerNormOp __all__ = ["LayerNorm"] -class _LayerNorm(torch.autograd.Function): - """functional LayerNorm""" - - @staticmethod - def forward( - ctx, - inp: torch.Tensor, - ln_weight: torch.Tensor, - ln_bias: torch.Tensor, - eps: float, - fwd_ln_sm_margin: int, - bwd_ln_sm_margin: int, - inf_ln_sm_margin: int, - zero_centered_gamma: bool, - is_grad_enabled: bool, - activation_dtype: torch.dtype, - ) -> torch.Tensor: - # pylint: disable=missing-function-docstring - # Make sure input dimensions are compatible - in_features = ln_weight.numel() - assert inp.is_cuda, "TransformerEngine needs CUDA." - assert inp.shape[-1] == in_features, "LayerNorm not possible" - inputmat = inp.view((-1, in_features)) - - # Cast for native AMP - inputmat = cast_if_needed(inputmat, activation_dtype) - ln_weight = cast_if_needed(ln_weight, activation_dtype) - ln_bias = cast_if_needed(ln_bias, activation_dtype) - - if is_grad_enabled: - ln_out, mu, rsigma = tex.layernorm_fwd( - inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin, zero_centered_gamma - ) - ctx.save_for_backward(inputmat, ln_weight, mu, rsigma) - ctx.inp_shape = inp.shape - ctx.bwd_ln_sm_margin = bwd_ln_sm_margin - ctx.zero_centered_gamma = zero_centered_gamma - else: - ln_out, mu, rsigma = ( - layernorm_fwd_inf( - inputmat, ln_weight, ln_bias, eps, inf_ln_sm_margin, zero_centered_gamma - ), - None, - None, - ) - return ln_out.view_as(inp) - - @staticmethod - def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: - # pylint: disable=missing-function-docstring - inputmat, ln_weight, mu, rsigma = ctx.saved_tensors - grad_output = grad_output.contiguous() - d_ln_out = grad_output.view(inputmat.shape) - dxmat, dgamma, dbeta = tex.layernorm_bwd( - d_ln_out, inputmat, mu, rsigma, ln_weight, ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma - ) - return dxmat.view(ctx.inp_shape), dgamma, dbeta, None, None, None, None, None, None, None - +class LayerNorm(_LayerNormOp): + r"""Layer Normalization -class LayerNorm(torch.nn.Module): - r""" Applies Layer Normalization over a mini-batch of inputs as described in the paper `Layer Normalization `__ .. math:: - y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * \gamma + \beta + y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * \gamma + \beta - :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of - size :attr:`hidden_size` + :math:`\gamma` and :math:`\beta` are learnable affine transform + parameters that match the inner-most dimensions of the input + tensor. Parameters ---------- - hidden_size : int - size of each input sample. 
+ normalized_shape: int or iterable of int + Inner dimensions of input tensor eps : float, default = 1e-5 - a value added to the denominator of layer normalization for numerical stability. - sequence_parallel : bool, default = `False` - if set to `True`, uses sequence parallelism. - params_dtype : torch.dtype, default = `torch.get_default_dtype()` - it controls the type used to allocate the initial parameters. Useful when - the model is trained with lower precision and the original FP32 parameters - would not fit in GPU memory. + A value added to the denominator of layer normalization for + numerical stability + device: torch.device, default = default CUDA device + Tensor device + dtype: torch.dtype, default = default dtype + Tensor datatype zero_centered_gamma : bool, default = 'False' - if set to 'True', gamma parameter in LayerNorm is initialized to 0 and - the LayerNorm formula changes to - - .. math:: - y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * - (1 + \gamma) + \beta - device : Union[torch.device, str], default = "cuda" - The device on which the parameters of the model will be allocated. It is the user's - responsibility to ensure all parameters are moved to the GPU before running the - forward pass. + If `True`, the :math:`\gamma` parameter is initialized to zero + and the calculation changes to + + .. math:: + y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta + + sm_margin: int or dict, default = 0 + Number of SMs to exclude when launching CUDA kernels. This + helps overlap with other kernels, e.g. communication kernels. + For more fine-grained control, provide a dict with the SM + margin at each compute stage ("forward", "backward", + "inference"). + + Legacy + ------ + sequence_parallel: bool + Set a bool attr named `sequence_parallel` in the parameters. + This is custom logic for Megatron-LM integration. 
+ """ def __init__( self, - hidden_size: int, + normalized_shape: Union[Iterable[int], int, None] = None, eps: float = 1e-5, - sequence_parallel: bool = False, - params_dtype: Optional[torch.dtype] = None, + sequence_parallel: Optional[bool] = None, # legacy + params_dtype: Optional[torch.dtype] = None, # deprecated zero_centered_gamma: bool = False, - device: Union[torch.device, str] = "cuda", + hidden_size: Optional[int] = None, # deprecated + **kwargs, ) -> None: - super().__init__() - params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype - self.eps = eps - self.zero_centered_gamma = zero_centered_gamma - self.weight = Parameter( - torch.empty( - hidden_size, - device=device, - dtype=params_dtype, + + # Handle deprecated options + if normalized_shape is None: + if hidden_size is None: + raise RuntimeError( + "Neither `normalized_shape` nor `hidden_size` (deprecated) args are provided" + ) + warnings.warn( + "`hidden_size` arg has been renamed to `normalized_shape` " + "for compatibility with `torch.nn.LayerNorm`.", + DeprecationWarning, + stacklevel=2, ) - ) - self.bias = Parameter( - torch.empty( - hidden_size, - device=device, - dtype=params_dtype, + normalized_shape = hidden_size + elif hidden_size is not None: + raise RuntimeError( + "Both `normalized_shape` and `hidden_size` (deprecated) args are provided" ) + if params_dtype is not None: + if "dtype" in kwargs: + raise RuntimeError( + "Both `dtype` and `params_dtype` (deprecated) kwargs are provided" + ) + kwargs["dtype"] = params_dtype + + # Initialize layer norm operation + super().__init__( + normalized_shape, + eps=eps, + zero_centered_gamma=zero_centered_gamma, + **kwargs, ) - self.sequence_parallel = sequence_parallel - self.activation_dtype: Optional[torch.dtype] = None - self.reset_parameters(defer_init=device == "meta") - - # These many SMs are subtracted from the total SM count when calling forward - # and backward LayerNorm C APIs. These envvars can be used to prevent the LN - # kernels from using all SMs in the device. This is useful for cases such as - # communication overlap with LN. 
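# Editor's sketch (not part of the patch): migrating to the renamed LayerNorm
# arguments handled above. Sizes are arbitrary example values; the module is
# assumed to be used as transformer_engine.pytorch.LayerNorm on a CUDA device.
import torch
import transformer_engine.pytorch as te

# Deprecated spelling: still accepted, but now emits DeprecationWarning.
ln_old = te.LayerNorm(hidden_size=1024, params_dtype=torch.bfloat16)

# New spelling mirrors torch.nn.LayerNorm. sm_margin may be an int or a dict
# keyed by compute stage; per the op implementation, unspecified stages fall
# back to the NVTE_{FWD,BWD,INF}_LAYERNORM_SM_MARGIN environment variables.
ln = te.LayerNorm(
    normalized_shape=1024,
    eps=1e-5,
    dtype=torch.bfloat16,
    zero_centered_gamma=True,
    sm_margin={"forward": 8, "backward": 8},
)
y = ln(torch.randn(16, 1024, device="cuda", dtype=torch.bfloat16))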
- self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0")) - self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) - self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0")) + # Flag for sequence parallelism (custom Megatron-LM integration) + self.sequence_parallel: Optional[bool] = sequence_parallel def reset_layer_norm_parameters(self) -> None: """Init LN params""" @@ -164,64 +113,62 @@ def reset_layer_norm_parameters(self) -> None: DeprecationWarning, stacklevel=2, ) - if not self.zero_centered_gamma: - init.ones_(self.weight) - else: - init.zeros_(self.weight) - init.zeros_(self.bias) + self.reset_parameters() - def reset_parameters(self, defer_init=False) -> None: + def reset_parameters(self, defer_init: Optional[bool] = None) -> None: """Init LayerNorm parameters""" - if defer_init: - return - - if self.weight.device == torch.device("meta"): - self.weight = torch.nn.Parameter(torch.empty_like(self.weight, device="cuda")) - setattr(self.weight, "sequence_parallel", self.sequence_parallel) - init.constant_(self.weight, float(not self.zero_centered_gamma)) - - if self.bias.device == torch.device("meta"): - self.bias = torch.nn.Parameter(torch.empty_like(self.bias, device="cuda")) - setattr(self.bias, "sequence_parallel", self.sequence_parallel) - init.zeros_(self.bias) - - @no_torch_dynamo() - def forward(self, inp: torch.Tensor) -> torch.Tensor: - # pylint: disable=missing-function-docstring - - # Set the activation type for AMP. - # Note: This will soon be deprecated with - # https://github.com/NVIDIA/TransformerEngine/pull/1033 - if torch.is_autocast_enabled(): - self.activation_dtype = torch.get_autocast_gpu_dtype() - elif self.activation_dtype != inp.dtype: - dtype = inp.dtype - for name, param in self.named_parameters(): - if param is not None: - assert dtype == param.dtype, ( - "Data types for parameters must match when outside of autocasted region. " - f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}" - ) - self.activation_dtype = dtype - - if torch.is_grad_enabled(): - fwd_fn = _LayerNorm.apply - args = [] - else: - fwd_fn = _LayerNorm.forward - args = [None] - - args += ( - inp, - self.weight, - self.bias, - self.eps, - self.fwd_ln_sm_margin, - self.bwd_ln_sm_margin, - self.inf_ln_sm_margin, - self.zero_centered_gamma, - torch.is_grad_enabled(), - self.activation_dtype, - ) - return fwd_fn(*args) + # Check whether to defer init (deprecated) + if defer_init is not None: + warnings.warn( + "defer_init argument to reset_parameters function is deprecated. 
Set device to" + ' "meta" instead.', + DeprecationWarning, + stacklevel=2, + ) + if defer_init: + return + + # Reset parameters + super().reset_parameters() + + # Set flag for sequence parallelism (custom Megatron-LM integration) + if getattr(self, "sequence_parallel", None) is not None: + self.weight.sequence_parallel = self.sequence_parallel + self.bias.sequence_parallel = self.sequence_parallel + + @property + def fwd_ln_sm_margin(self) -> int: + """Shim for backward compatibility""" + warnings.warn("fwd_ln_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2) + return self._sm_margins["forward"] + + @fwd_ln_sm_margin.setter + def fwd_ln_sm_margin(self, val: int) -> None: + """Shim for backward compatibility""" + warnings.warn("fwd_ln_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2) + self._sm_margins["forward"] = val + + @property + def bwd_ln_sm_margin(self) -> int: + """Shim for backward compatibility""" + warnings.warn("bwd_ln_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2) + return self._sm_margins["backward"] + + @bwd_ln_sm_margin.setter + def bwd_ln_sm_margin(self, val: int) -> None: + """Shim for backward compatibility""" + warnings.warn("bwd_ln_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2) + self._sm_margins["backward"] = val + + @property + def inf_ln_sm_margin(self) -> int: + """Shim for backward compatibility""" + warnings.warn("inf_ln_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2) + return self._sm_margins["inference"] + + @inf_ln_sm_margin.setter + def inf_ln_sm_margin(self, val: int) -> None: + """Shim for backward compatibility""" + warnings.warn("inf_ln_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2) + self._sm_margins["inference"] = val diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 97006a0671..fbf1b97704 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -161,9 +161,9 @@ def forward( if not return_layernorm_output: ln_out = torch.empty_like(ln_out) if ub_obj_lnout.is_atomic_gemm(): - ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P + ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P else: - ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P + ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P elif parallel_mode == "column" and sequence_parallel: ln_out_gathered = True ln_out_total, _ = gather_along_first_dim(ln_out, tp_group) @@ -293,7 +293,7 @@ def forward( get_workspace(), bias=bias, use_bias=use_bias, - ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P if ub_overlap_ag else None, + ub_algo=tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P if ub_overlap_ag else None, ub=ub_obj_lnout if ub_overlap_ag else None, extra_output_tensor=ln_out if ub_overlap_ag else None, ) @@ -485,7 +485,7 @@ def backward( rs_out = None if ctx.ub_bulk_dgrad: - ub_algo = tex.UbufOverlapAlgo.BULK_OVERLAP_AG + ub_algo = tex.CommOverlapAlgo.BULK_OVERLAP_AG ub_obj = ub_obj_lnout elif ctx.ub_overlap_rs_dgrad: dim_size = list(grad_output.size()) @@ -496,14 +496,14 @@ def backward( ) if ub_obj_dgrad.is_p2p_overlap(): if ctx.fp8 and ub_obj_dgrad.is_atomic_gemm(): - ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P + ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P else: - ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P + ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P else: if ctx.fp8 and ub_obj_dgrad.is_atomic_gemm(): - ub_algo = 
tex.UbufOverlapAlgo.ATOMIC_GEMM_RS + ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_RS else: - ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS + ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS ub_obj = ub_obj_dgrad else: ub_algo = None @@ -616,7 +616,7 @@ def backward( out=weight.main_grad if ctx.fuse_wgrad_accumulation else None, use_split_accumulator=_2X_ACC_WGRAD, ub_algo=( - tex.UbufOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None + tex.CommOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None ), ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None, extra_output_tensor=extra_output_tensor, @@ -640,7 +640,7 @@ def backward( accumulate=accumulate_wgrad_into_param_main_grad, out=weight.main_grad if ctx.fuse_wgrad_accumulation else None, ub_algo=( - tex.UbufOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None + tex.CommOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None ), ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None, extra_output_tensor=extra_output_tensor, @@ -658,7 +658,7 @@ def backward( use_bias=ctx.use_bias, accumulate=accumulate_wgrad_into_param_main_grad, out=weight.main_grad if ctx.fuse_wgrad_accumulation else None, - ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None, + ub_algo=tex.CommOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None, ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None, ) clear_tensor_data(ln_out_total) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 966924a85c..64e8c9ce36 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -180,9 +180,9 @@ def forward( ln_out_total = ub_obj_lnout.get_ubuf_output(1) ln_out = torch.empty_like(ln_out) if ub_obj_lnout.is_atomic_gemm(): - ub_algo_ag = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P + ub_algo_ag = tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P else: - ub_algo_ag = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P + ub_algo_ag = tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P elif set_parallel_mode and sequence_parallel: ln_out_gathered = True ln_out_total, _ = gather_along_first_dim(ln_out, tp_group) @@ -298,14 +298,14 @@ def forward( rs_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device) if ub_obj_fc2out.is_p2p_overlap(): if ub_obj_fc2out.is_atomic_gemm(): - ub_algo_rs = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P + ub_algo_rs = tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P else: - ub_algo_rs = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P + ub_algo_rs = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P else: if ub_obj_fc2out.is_atomic_gemm(): - ub_algo_rs = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS + ub_algo_rs = tex.CommOverlapAlgo.ATOMIC_GEMM_RS else: - ub_algo_rs = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS + ub_algo_rs = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS if ub_obj_fc2out.is_fp8_ubuf(): fc2_out_index = tex.FP8FwdTensors.GEMM2_OUTPUT @@ -369,7 +369,7 @@ def forward( bias=fc1_bias, use_bias=(not bias_gelu_nvfusion) and use_fc1_bias, gelu=not bias_gelu_nvfusion and (activation == "gelu"), - ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P if ub_overlap_ag else None, + ub_algo=tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P if ub_overlap_ag else None, ub=ub_obj_lnout if ub_overlap_ag else None, extra_output_tensor=ln_out if ub_overlap_ag else None, ) @@ -410,9 +410,9 @@ def forward( dim_size[1] = fc2_weight.size(0) rs_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device) if ub_obj_fc2out.is_p2p_overlap(): - ub_algo_rs = 
tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P + ub_algo_rs = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P else: - ub_algo_rs = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS + ub_algo_rs = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS else: dim_size = list(gelu_out.size()) dim_size[1] = fc2_weight.size(0) @@ -615,9 +615,9 @@ def backward( dim_size[0] = dim_size[0] * tp_world_size ctx.ub_obj_gradout = get_ub("fc2_dgrad") if ctx.ub_obj_gradout.is_atomic_gemm(): - ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P + ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P else: - ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P + ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P ctx.use_bias = ctx.use_fc2_bias # For grad_output_preprocess ( @@ -788,7 +788,7 @@ def backward( # Set UB algo and UB obj for fc1_dgrad bulk/pipelined overlap rs_out = None if ctx.ub_bulk_dgrad: - ub_algo = tex.UbufOverlapAlgo.BULK_OVERLAP_AG + ub_algo = tex.CommOverlapAlgo.BULK_OVERLAP_AG ub_obj = ub_obj_lnout elif ctx.ub_overlap_rs_dgrad: dim_size = list(dgelu.size()) @@ -797,14 +797,14 @@ def backward( rs_out = torch.empty(dim_size, dtype=ctx.activation_dtype, device=dgelu.device) if ub_obj_dgrad.is_p2p_overlap(): if ub_obj_dgrad.is_atomic_gemm(): - ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P + ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P else: - ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P + ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P else: if ub_obj_dgrad.is_atomic_gemm(): - ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS + ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_RS else: - ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS + ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS ub_obj = ub_obj_dgrad else: ub_algo = None @@ -842,7 +842,7 @@ def backward( grad=True, gelu_input=fc1_out, ub_algo=( - tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P if ctx.ub_overlap_ag else None + tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P if ctx.ub_overlap_ag else None ), ub=ctx.ub_obj_gradout if ctx.ub_overlap_ag else None, ) @@ -892,7 +892,7 @@ def backward( # Set UB algo and UB obj for fc1_dgrad bulk/pipelined overlap if ctx.ub_bulk_dgrad: - ub_algo = tex.UbufOverlapAlgo.BULK_OVERLAP_AG + ub_algo = tex.CommOverlapAlgo.BULK_OVERLAP_AG ub_obj = ub_obj_lnout elif ctx.ub_overlap_rs_dgrad: dim_size = list(dgelu.size()) @@ -900,9 +900,9 @@ def backward( dim_size[1] = fc1_weight.size(1) rs_out = torch.empty(dim_size, dtype=ctx.activation_dtype, device=dgelu.device) if ub_obj_dgrad.is_p2p_overlap(): - ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P + ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P else: - ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS + ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS ub_obj = ub_obj_dgrad else: ub_algo = None @@ -967,7 +967,7 @@ def backward( out=fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None, use_split_accumulator=_2X_ACC_WGRAD, ub_algo=( - tex.UbufOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None + tex.CommOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None ), ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None, extra_output_tensor=extra_output_tensor, @@ -991,7 +991,7 @@ def backward( accumulate=accumulate_wgrad_into_param_main_grad, out=fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None, ub_algo=( - tex.UbufOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None + tex.CommOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None ), ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None, extra_output_tensor=extra_output_tensor, @@ -1009,7 +1009,7 @@ def backward( 
use_bias=not ctx.bias_gelu_nvfusion, accumulate=accumulate_wgrad_into_param_main_grad, out=fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None, - ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None, + ub_algo=tex.CommOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None, ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None, ) clear_tensor_data(ln_out_total, dgelu) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 403eef091f..1fed467210 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -190,14 +190,14 @@ def forward( rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device) if ub_obj_projout.is_p2p_overlap(): if ub_obj_projout.is_atomic_gemm(): - ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P + ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P else: - ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P + ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P else: if ub_obj_projout.is_atomic_gemm(): - ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS + ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_RS else: - ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS + ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS if ub_obj_projout.is_fp8_ubuf(): proj_out_index = tex.FP8FwdTensors.GEMM1_OUTPUT meta_tensor = fp8_meta["scaling_fwd"] @@ -269,9 +269,9 @@ def forward( dim_size[1] = out_features rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device) if ub_obj_projout.is_p2p_overlap(): - ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P + ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P else: - ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS + ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS else: dim_size = list(inputmat_total.size()) dim_size[1] = out_features @@ -407,9 +407,9 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], dim_size[0] = dim_size[0] * tp_world_size ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad") if ctx.ub_obj_gradout.is_atomic_gemm(): - ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P + ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P else: - ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P + ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P ( grad_output, @@ -496,7 +496,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], layout="NN", grad=True, ub_algo=( - tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P + tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P if ctx.ub_overlap_ag else None ), diff --git a/transformer_engine/pytorch/module/rmsnorm.py b/transformer_engine/pytorch/module/rmsnorm.py index fc6ec5746f..bd7db1f775 100644 --- a/transformer_engine/pytorch/module/rmsnorm.py +++ b/transformer_engine/pytorch/module/rmsnorm.py @@ -3,221 +3,175 @@ # See LICENSE for license information. """RMSNorm API""" -import os import warnings -from typing import Union, Tuple, Optional +from typing import Iterable, Optional, Union import torch -from torch.nn.parameter import Parameter -from torch.nn import init - -from .. 
import cpp_extensions as tex -from ..jit import no_torch_dynamo -from ..utils import cast_if_needed +from transformer_engine.pytorch.ops import RMSNorm as _RMSNormOp __all__ = ["RMSNorm"] -class _RMSNorm(torch.autograd.Function): - """functional RMSNorm""" - - @staticmethod - def forward( - ctx, - inp: torch.Tensor, - rmsnorm_weight: torch.Tensor, - eps: float, - fwd_rmsnorm_sm_margin: int, - bwd_rmsnorm_sm_margin: int, - inf_rmsnorm_sm_margin: int, - zero_centered_gamma: bool, - is_grad_enabled: bool, - activation_dtype: torch.dtype, - ) -> torch.Tensor: - # pylint: disable=missing-function-docstring - # Make sure input dimensions are compatible - in_features = rmsnorm_weight.numel() - assert inp.is_cuda, "TransformerEngine needs CUDA." - assert inp.shape[-1] == in_features, "RMSNorm not possible" - inputmat = inp.view((-1, in_features)) - - # Cast for native AMP - inputmat = cast_if_needed(inputmat, activation_dtype) - rmsnorm_weight = cast_if_needed(rmsnorm_weight, activation_dtype) - - if is_grad_enabled: - rmsnorm_out, rsigma = tex.rmsnorm_fwd( - inputmat, rmsnorm_weight, eps, fwd_rmsnorm_sm_margin, zero_centered_gamma - ) - ctx.save_for_backward(inputmat, rmsnorm_weight, rsigma) - ctx.inp_shape = inp.shape - ctx.bwd_rmsnorm_sm_margin = bwd_rmsnorm_sm_margin - ctx.zero_centered_gamma = zero_centered_gamma - else: - rmsnorm_out = tex.rmsnorm_fwd_inf( - inputmat, rmsnorm_weight, eps, inf_rmsnorm_sm_margin, zero_centered_gamma - ) - return rmsnorm_out.view_as(inp) - - @staticmethod - def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: - # pylint: disable=missing-function-docstring - inputmat, rmsnorm_weight, rsigma = ctx.saved_tensors - grad_output = grad_output.contiguous() - d_rmsnorm_out = grad_output.view(inputmat.shape) - dxmat, dgamma = tex.rmsnorm_bwd( - d_rmsnorm_out, - inputmat, - rsigma, - rmsnorm_weight, - ctx.bwd_rmsnorm_sm_margin, - ctx.zero_centered_gamma, - ) - return ( - dxmat.view(ctx.inp_shape), - dgamma, - None, - None, - None, - None, - None, - None, - None, - ) - +class RMSNorm(_RMSNormOp): + r"""Root Mean Square Layer Normalization -class RMSNorm(torch.nn.Module): - r""" - Applies Root Mean Square Layer Normalization over a mini-batch of inputs as described in - the paper `Root Mean Square Layer Normalization `__ + Applies Root Mean Square Layer Normalization over a mini-batch of + inputs as described in the paper + `Root Mean Square Layer Normalization `__ .. math:: - y = \frac{x}{RMS_\varepsilon(x)} * \gamma + y = \frac{x}{\text{RMS}_\varepsilon(x)} * \gamma where .. math:: - RMS_\varepsilon(x) = \sqrt{\frac{1}{n}\sum_{i=0}^nx_i^2 + \varepsilon} + \text{RMS}_\varepsilon(x) = \sqrt{\frac{1}{n}\sum_{i=0}^n x_i^2 + \varepsilon} - :math:`\gamma` is a learnable affine transform parameter of size :attr:`hidden_size` + :math:`\gamma` is a learnable affine transform parameter that + matches the inner-most dimensions of the input tensor. Parameters ---------- - hidden_size : int - size of each input sample. + normalized_shape: int or iterable of int + Inner dimensions of input tensor eps : float, default = 1e-5 - a value added to the denominator of layer normalization for numerical stability. - sequence_parallel : bool, default = `False` - if set to `True`, uses sequence parallelism. - params_dtype : torch.dtype, default = `torch.get_default_dtype()` - it controls the type used to allocate the initial parameters. Useful when - the model is trained with lower precision and the original FP32 parameters - would not fit in GPU memory. 
+ A value added to the denominator for numerical stability + device: torch.device, default = default CUDA device + Tensor device + dtype: torch.dtype, default = default dtype + Tensor datatype zero_centered_gamma : bool, default = 'False' - if set to 'True', gamma parameter in RMSNorm is initialized to 0 and - the RMSNorm formula changes to - - .. math:: - y = \frac{x}{RMS_\varepsilon(x)} * (1 + \gamma) - device : Union[torch.device, str], default = "cuda" - The device on which the parameters of the model will be allocated. It is the user's - responsibility to ensure all parameters are moved to the GPU before running the - forward pass. + If `True`, the :math:`\gamma` parameter is initialized to zero + and the calculation changes to + + .. math:: + y = \frac{x}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + + sm_margin: int, default = 0 + Number of SMs to exclude when launching CUDA kernels. This + helps overlap with other kernels, e.g. communication kernels. + For more fine-grained control, provide a dict with the SM + margin at each compute stage ("forward", "backward", + "inference"). + + Legacy + ------ + sequence_parallel: bool + Set a bool attr named `sequence_parallel` in the parameters. + This is custom logic for Megatron-LM integration. + """ def __init__( self, - hidden_size: int, + normalized_shape: Union[Iterable[int], int, None] = None, eps: float = 1e-5, - sequence_parallel: bool = False, - params_dtype: Optional[torch.dtype] = None, + sequence_parallel: Optional[bool] = None, # legacy + params_dtype: Optional[torch.dtype] = None, # deprecated zero_centered_gamma: bool = False, - device: Union[torch.device, str] = "cuda", + hidden_size: Optional[int] = None, # deprecated + **kwargs, ) -> None: - super().__init__() - params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype - self.eps = eps - self.zero_centered_gamma = zero_centered_gamma - self.weight = Parameter( - torch.empty( - hidden_size, - device=device, - dtype=params_dtype, + + # Handle deprecated options + if normalized_shape is None: + if hidden_size is None: + raise RuntimeError( + "Neither `normalized_shape` nor `hidden_size` (deprecated) args are provided" + ) + warnings.warn( + "`hidden_size` arg has been renamed to `normalized_shape` " + "for compatibility with `torch.nn.LayerNorm`.", + DeprecationWarning, + stacklevel=2, + ) + normalized_shape = hidden_size + elif hidden_size is not None: + raise RuntimeError( + "Both `normalized_shape` and `hidden_size` (deprecated) args are provided" ) + if params_dtype is not None: + if "dtype" in kwargs: + raise RuntimeError( + "Both `dtype` and `params_dtype` (deprecated) kwargs are provided" + ) + kwargs["dtype"] = params_dtype + + # Initialize RMSNorm operation + super().__init__( + normalized_shape, + eps=eps, + zero_centered_gamma=zero_centered_gamma, + **kwargs, ) - self.sequence_parallel = sequence_parallel - self.activation_dtype: Optional[torch.dtype] = None - self.reset_parameters(defer_init=device == "meta") - - # These many SMs are subtracted from the total SM count when calling forward - # and backward RMSNorm C APIs. These envvars can be used to prevent the LN - # kernels from using all SMs in the device. This is useful for cases such as - # communication overlap with RMSNorm. 
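# Editor's sketch (not part of the patch): checking the RMSNorm formula in the
# docstring above against a manual computation. Sizes and tolerances are
# arbitrary example values; assumes a CUDA build of Transformer Engine.
import torch
import transformer_engine.pytorch as te

hidden, eps = 512, 1e-5
rmsnorm = te.RMSNorm(normalized_shape=hidden, eps=eps, dtype=torch.float32)

x = torch.randn(8, hidden, device="cuda")
y = rmsnorm(x)

# Reference: y = x / sqrt(mean(x^2) + eps) * gamma, with gamma initialized to ones.
ref = x / torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps) * rmsnorm.weight
torch.testing.assert_close(y, ref, rtol=1e-3, atol=1e-3)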
- self.fwd_rmsnorm_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0")) - self.bwd_rmsnorm_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) - self.inf_rmsnorm_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0")) + # Flag for sequence parallelism (custom Megatron-LM integration) + self.sequence_parallel: Optional[bool] = sequence_parallel def reset_rms_norm_parameters(self) -> None: - """Init RMSNorm params""" + """Deprecated""" warnings.warn( "This method is deprecated and will be removed in an upcoming release. " "Update your code to use RMSNorm.reset_parameters() instead.", DeprecationWarning, stacklevel=2, ) - if not self.zero_centered_gamma: - init.ones_(self.weight) - else: - init.zeros_(self.weight) - - def reset_parameters(self, defer_init=False) -> None: - """Reset RMSNorm parameters""" - if defer_init: - return - - if self.weight.device == torch.device("meta"): - self.weight = torch.nn.Parameter(torch.empty_like(self.weight, device="cuda")) - init.constant_(self.weight, float(not self.zero_centered_gamma)) - setattr(self.weight, "sequence_parallel", self.sequence_parallel) - - @no_torch_dynamo() - def forward(self, inp: torch.Tensor) -> torch.Tensor: - # pylint: disable=missing-function-docstring - - # Set the activation type for AMP. - # Note: This will soon be deprecated with - # https://github.com/NVIDIA/TransformerEngine/pull/1033 - if torch.is_autocast_enabled(): - self.activation_dtype = torch.get_autocast_gpu_dtype() - elif self.activation_dtype != inp.dtype: - dtype = inp.dtype - for name, param in self.named_parameters(): - if param is not None: - assert dtype == param.dtype, ( - "Data types for parameters must match when outside of autocasted region. " - f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}" - ) - self.activation_dtype = dtype - - if torch.is_grad_enabled(): - fwd_fn = _RMSNorm.apply - args = [] - else: - fwd_fn = _RMSNorm.forward - args = [None] - - args += ( - inp, - self.weight, - self.eps, - self.fwd_rmsnorm_sm_margin, - self.bwd_rmsnorm_sm_margin, - self.inf_rmsnorm_sm_margin, - self.zero_centered_gamma, - torch.is_grad_enabled(), - self.activation_dtype, - ) - - return fwd_fn(*args) + self.reset_parameters() + + def reset_parameters(self, defer_init: Optional[bool] = None) -> None: + """Init RMSNorm parameters""" + + # Check whether to defer init (deprecated) + if defer_init is not None: + warnings.warn( + "defer_init argument to reset_parameters function is deprecated. 
Set device to" + ' "meta" instead.', + DeprecationWarning, + stacklevel=2, + ) + if defer_init: + return + + # Reset parameters + super().reset_parameters() + + # Flag for sequence parallelism (custom Megatron-LM integration) + if getattr(self, "sequence_parallel", None) is not None: + self.weight.sequence_parallel = self.sequence_parallel + + @property + def fwd_rmsnorm_sm_margin(self) -> int: + """Shim for backward compatibility""" + warnings.warn("fwd_rmsnorm_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2) + return self._sm_margins["forward"] + + @fwd_rmsnorm_sm_margin.setter + def fwd_rmsnorm_sm_margin(self, val: int) -> None: + """Shim for backward compatibility""" + warnings.warn("fwd_rmsnorm_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2) + self._sm_margins["forward"] = val + + @property + def bwd_rmsnorm_sm_margin(self) -> int: + """Shim for backward compatibility""" + warnings.warn("bwd_rmsnorm_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2) + return self._sm_margins["backward"] + + @bwd_rmsnorm_sm_margin.setter + def bwd_rmsnorm_sm_margin(self, val: int) -> None: + """Shim for backward compatibility""" + warnings.warn("bwd_rmsnorm_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2) + self._sm_margins["backward"] = val + + @property + def inf_rmsnorm_sm_margin(self) -> int: + """Shim for backward compatibility""" + warnings.warn("inf_rmsnorm_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2) + return self._sm_margins["inference"] + + @inf_rmsnorm_sm_margin.setter + def inf_rmsnorm_sm_margin(self, val: int) -> None: + """Shim for backward compatibility""" + warnings.warn("inf_rmsnorm_sm_margin attr is deprecated", DeprecationWarning, stacklevel=2) + self._sm_margins["inference"] = val diff --git a/transformer_engine/pytorch/ops/__init__.py b/transformer_engine/pytorch/ops/__init__.py index f437f877b4..f65433398e 100644 --- a/transformer_engine/pytorch/ops/__init__.py +++ b/transformer_engine/pytorch/ops/__init__.py @@ -8,17 +8,7 @@ """ -from transformer_engine.pytorch.ops.basic import ( - AddInPlace, - AllGather, - AllReduce, - BasicLinear, - Bias, - Identity, - MakeExtraOutput, - ReduceScatter, - Reshape, -) +from transformer_engine.pytorch.ops.basic import * from transformer_engine.pytorch.ops.linear import Linear from transformer_engine.pytorch.ops.op import FusibleOperation from transformer_engine.pytorch.ops.sequential import Sequential diff --git a/transformer_engine/pytorch/ops/_common.py b/transformer_engine/pytorch/ops/_common.py index 12270d8340..b1654add98 100644 --- a/transformer_engine/pytorch/ops/_common.py +++ b/transformer_engine/pytorch/ops/_common.py @@ -9,6 +9,8 @@ import torch +from transformer_engine_torch import FP8TensorMeta +from ..fp8 import FP8GlobalStateManager from ..tensor import Float8Tensor from ..utils import ( canonicalize_device, # pylint: disable=unused-import @@ -56,6 +58,8 @@ def convert_tensor( if memory_format != torch.preserve_format and not data.is_contiguous( memory_format=memory_format ): + # Note: torch.Tensor.to ignores memory_format kwarg (see + # https://github.com/pytorch/pytorch/issues/132020). 
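# Editor's sketch (not part of the patch): the workaround noted above, shown in
# isolation with torch.channels_last as an example format. Because Tensor.to()
# does not reliably honor its memory_format argument (see the linked issue), the
# layout is enforced with an explicit .contiguous() call after the dtype cast.
import torch


def to_with_memory_format(
    t: torch.Tensor,
    *,
    dtype: torch.dtype,
    memory_format: torch.memory_format = torch.preserve_format,
) -> torch.Tensor:
    """Convert dtype, then enforce the requested memory format explicitly."""
    t = t.to(dtype=dtype)
    if memory_format != torch.preserve_format and not t.is_contiguous(
        memory_format=memory_format
    ):
        t = t.contiguous(memory_format=memory_format)
    return t


x = torch.randn(2, 3, 8, 8)
y = to_with_memory_format(x, dtype=torch.bfloat16, memory_format=torch.channels_last)
assert y.is_contiguous(memory_format=torch.channels_last)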
data = data.contiguous(memory_format=memory_format) return Float8Tensor.make_like( tensor, @@ -65,7 +69,14 @@ def convert_tensor( ) # Convert standard PyTorch tensor - return tensor.to(device=device, dtype=dtype, memory_format=memory_format) + tensor = tensor.to(device=device, dtype=dtype) + if memory_format != torch.preserve_format and not tensor.is_contiguous( + memory_format=memory_format + ): + # Note: torch.Tensor.to ignores memory_format kwarg (see + # https://github.com/pytorch/pytorch/issues/132020). + tensor = tensor.contiguous(memory_format=memory_format) + return tensor def reshape( @@ -114,3 +125,36 @@ def reshape( # Reshape standard PyTorch tensor return tensor.view(shape) + + +def maybe_autocast_dtype( + *, + device_type: str = "cuda", + default_dtype: Optional[torch.dtype] = None, +) -> torch.dtype: + """Get autocast dtype if enabled""" + if torch.is_autocast_enabled(device_type): + return torch.get_autocast_dtype(device_type) + return canonicalize_dtype(default_dtype) + + +def get_fp8_meta_from_fp8_tensor(tensor: Float8Tensor) -> tuple[FP8TensorMeta, int]: + """Get FP8TensorMeta object and index corresponding to Float8Tensor + + Constructs FP8TensorMeta if needed. + + """ + + # Check if tensor already has FP8 metadata + if tensor._fp8_meta is not None: + key = FP8GlobalStateManager.get_meta_tensor_key( + forward=tensor._fp8_meta_forward, + ) + return tensor._fp8_meta[key], tensor._fp8_meta_index + + # Create FP8TensorMeta class + fp8_meta = FP8TensorMeta() + fp8_meta.scale = tensor._scale_inv.reciprocal() + fp8_meta.amax_history = torch.empty(1, 1, dtype=torch.float32, device=tensor.device) + fp8_meta.scale_inv = tensor._scale_inv + return fp8_meta, 0 diff --git a/transformer_engine/pytorch/ops/basic/__init__.py b/transformer_engine/pytorch/ops/basic/__init__.py index 1003cc0337..d6f4940c58 100644 --- a/transformer_engine/pytorch/ops/basic/__init__.py +++ b/transformer_engine/pytorch/ops/basic/__init__.py @@ -4,12 +4,16 @@ """Single tensor operations supported by the operation fuser.""" +from .activation import GELU, ReLU, GEGLU, ReGLU, SwiGLU from .add_in_place import AddInPlace from .all_gather import AllGather from .all_reduce import AllReduce from .basic_linear import BasicLinear from .bias import Bias from .identity import Identity +from .layer_norm import LayerNorm from .make_extra_output import MakeExtraOutput +from .quantize import Quantize from .reduce_scatter import ReduceScatter from .reshape import Reshape +from .rmsnorm import RMSNorm diff --git a/transformer_engine/pytorch/ops/basic/activation.py b/transformer_engine/pytorch/ops/basic/activation.py new file mode 100644 index 0000000000..a2e5a24a85 --- /dev/null +++ b/transformer_engine/pytorch/ops/basic/activation.py @@ -0,0 +1,390 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
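# Editor's sketch (not part of the patch): expected behavior of the
# maybe_autocast_dtype helper added above. It relies on the device-type string
# overloads of torch.is_autocast_enabled / torch.get_autocast_dtype available in
# recent PyTorch releases; the import path matches this file's location.
import torch
from transformer_engine.pytorch.ops._common import maybe_autocast_dtype

print(maybe_autocast_dtype())                             # default dtype, e.g. torch.float32
print(maybe_autocast_dtype(default_dtype=torch.float16))  # explicit fallback
with torch.autocast("cuda", dtype=torch.bfloat16):
    print(maybe_autocast_dtype())                         # torch.bfloat16 inside autocast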
+ +"""Fusible operations for activation functions.""" + +from __future__ import annotations +import abc +from typing import Optional + +import torch + +import transformer_engine_torch +from ...constants import TE_DType +from ...cpp_extensions import ( + geglu as tex_geglu, + gelu as tex_gelu, + reglu as tex_reglu, + relu as tex_relu, + swiglu as tex_swiglu, + fp8_dswiglu_cast_transpose_fused, +) +from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype +from ...tensor import Float8Tensor, QuantizedTensor +from ...utils import clear_tensor_data, devices_match +from ..op import BasicOperation, OperationContext + + +class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta): + r"""Apply activation function + + Activation functions are either element-wise unary functions or + variants of the gated linear unit (GLU). Recall that GLU is + computed by splitting the input tensor into chunks :math:`a` and + :math:`b` along the last dimension and computing + + .. math:: + \text{GLU}(a,b) = \sigma(a) * b + + .. warning:: + + Transformer Engine gated activations and PyTorch's GLU + activation follow opposite conventions for :math:`a` and + :math:`b`. Transformer Engine applies the gating function to + the first half of the input tensor, while PyTorch applies it to + the second half. + + """ + + @abc.abstractmethod + def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: + """Forward implementation + + Implementation from transformer_engine.pytorch.cpp_extensions. + + """ + + @abc.abstractmethod + def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: + """Backward implementation + + Implementation from transformer_engine_torch. + + """ + + def op_forward( + self, + ctx: OperationContext, + input_: torch.Tensor, + prev_op: Optional[BasicOperation] = None, + next_op: Optional[BasicOperation] = None, + ) -> torch.Tensor: + + # Compute dtype + dtype: torch.dtype + if torch.is_autocast_enabled(): + dtype = torch.get_autocast_dtype("cuda") + else: + dtype = input_.dtype + if dtype not in (torch.float32, torch.float16, torch.bfloat16): + raise RuntimeError(f"Unsupported dtype ({dtype})") + + # Check input tensor + x = input_ + if isinstance(x, QuantizedTensor): + x = x.dequantize() + if x.device.type != "cuda": + x = x.cuda() + if x.dtype != dtype: + x = x.to(dtype=dtype) + if not x.is_contiguous(): + x = x.contiguous() + + # Check if FP8 is enabled + fp8_enabled = FP8GlobalStateManager.is_fp8_enabled() + with_fp8_output = False + output_fp8_meta = None + output_dtype = TE_DType[dtype] + output_fp8_scale_inv = None + if fp8_enabled and next_op is not None and next_op.num_fp8_scales("input") > 0: + with_fp8_output = True + fp8_meta = next_op.get_fp8_meta("input") + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True) + output_fp8_meta = fp8_meta[fp8_meta_key] + output_dtype = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) + output_fp8_scale_inv = torch.empty([1], dtype=torch.float32, device=x.device) + + # Launch kernel + y = self._activation_forward_impl( + x, + output_fp8_meta, + 0, + output_dtype, + scale_inv=output_fp8_scale_inv, + ) + + # Check output tensor + if y.dim() != x.dim(): + y = y.reshape(list(x.shape[:-1]) + [-1]) + if with_fp8_output: + y = Float8Tensor( + data=y, + fp8_meta=output_fp8_meta, + fp8_meta_forward=True, + fp8_meta_index=0, + fp8_dtype=output_dtype, + fp8_scale_inv=output_fp8_scale_inv, + dtype=dtype, + ) + + # Save state for backward pass + ctx.save_for_backward(x) + ctx.fp8_enabled = fp8_enabled + ctx.prev_op = 
prev_op + + return y + + def op_backward( + self, + ctx: OperationContext, + grad_output: torch.Tensor, + ) -> tuple[torch.Tensor, tuple[()]]: + + # Saved tensors from forward pass + (x,) = ctx.saved_tensors + + # Check grad output tensor + dy = grad_output + if isinstance(dy, QuantizedTensor): + dy = dy.dequantize() + if not devices_match(dy.device, x.device) or dy.dtype != x.dtype: + dy = dy.to(device=x.device, dtype=x.dtype) + if not dy.is_contiguous(): + dy = dy.contiguous() + + # Launch kernel + dx = self._activation_backward_impl(dy, x, TE_DType[x.dtype]) + + # Check grad input tensor + if dx.size() != x.size(): + dx = dx.reshape(x.size()) + + # Clear input tensor if possible + if ctx.prev_op is not None: + clear_tensor_data(x) + + return dx, () + + +class GELU(_ActivationOperation): + r"""Gaussian Error Linear Unit + + This computes the "tanh" approximation to GELU: + + .. math:: + + \text{GELU}(x) \approx \frac{x}{2} \left( 1 + \tanh\left( 0.797x+0.036 x^3 \right) \right) + + See `Gaussian Error Linear Units (GELUs)`__. + + """ + + def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: + return tex_gelu(*args, **kwargs) + + def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: + return transformer_engine_torch.dgelu(*args, **kwargs) + + +class ReLU(_ActivationOperation): + r"""Rectified linear unit + + .. math:: + + \text{ReLU}(x) = \max(x,0) + + """ + + def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: + return tex_relu(*args, **kwargs) + + def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: + return transformer_engine_torch.drelu(*args, **kwargs) + + +class GEGLU(_ActivationOperation): + r"""Gaussian error gated linear unit + + The input tensor is split into chunks :math:`a` and :math:`b` + along the last dimension and the following is computed: + + .. math:: + + \text{GEGLU}(a,b) = \text{GELU}(a) * b + + where + + .. math:: + + \text{GELU}(x) \approx \frac{x}{2} \left( 1 + \tanh\left( 0.797x+0.036 x^3 \right) \right) + + .. warning:: + + Transformer Engine's gated activations and PyTorch's GLU + activation follow opposite conventions for :math:`a` and + :math:`b`. Transformer Engine applies the gating function to + the first half of the input tensor, while PyTorch applies it to + the second half. + + See `GLU Variants Improve Transformer`__. + + """ + + def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: + return tex_geglu(*args, **kwargs) + + def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: + return transformer_engine_torch.dgeglu(*args, **kwargs) + + +class ReGLU(_ActivationOperation): + r"""Rectified gated linear unit + + The input tensor is split into chunks :math:`a` and :math:`b` + along the last dimension and the following is computed: + + .. math:: + + \text{ReGLU}(a,b) = \max(a,0) * b + + .. warning:: + + Transformer Engine's gated activations and PyTorch's GLU + activation follow opposite conventions for :math:`a` and + :math:`b`. Transformer Engine applies the gating function to + the first half of the input tensor, while PyTorch applies it to + the second half. + + See `GLU Variants Improve Transformer`__. 
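# Editor's sketch (not part of the patch): the gating convention called out in
# the warnings above, written with plain PyTorch ops as a reference. No TE kernel
# is invoked; this only illustrates which half of the input each convention gates.
import torch
import torch.nn.functional as F

x = torch.randn(4, 8)        # last dim splits into two chunks of 4
a, b = x.chunk(2, dim=-1)

te_style_reglu = F.relu(a) * b                     # TE convention: gate the FIRST half
te_style_swiglu = F.silu(a) * b
te_style_geglu = F.gelu(a, approximate="tanh") * b

torch_glu = F.glu(x, dim=-1)                       # PyTorch convention: a * sigmoid(b),
assert torch.allclose(torch_glu, a * torch.sigmoid(b))  # i.e. sigmoid on the SECOND half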
+ + """ + + def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: + return tex_reglu(*args, **kwargs) + + def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: + return transformer_engine_torch.dreglu(*args, **kwargs) + + +class SwiGLU(_ActivationOperation): + r"""Swish gated linear unit + + The input tensor is split into chunks :math:`a` and :math:`b` + along the last dimension and the following is computed: + + .. math:: + + \text{GEGLU}(a,b) = \text{SiLU}(a) * b + + where + + .. math:: + + \text{SiLU}(x) = x \sigma(x) = \frac{x}{1+\exp(-x)} + + .. warning:: + + Transformer Engine's gated activations and PyTorch's GLU + activation follow opposite conventions for :math:`a` and + :math:`b`. Transformer Engine applies the gating function to + the first half of the input tensor, while PyTorch applies it to + the second half. + + The Sigmoid Linear Unit (SiLU) gating function is also known as + the swish function. See + `GLU Variants Improve Transformer`__ + and `Gaussian Error Linear Units (GELUs)`__. + + """ + + def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: + return tex_swiglu(*args, **kwargs) + + def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: + return transformer_engine_torch.dswiglu(*args, **kwargs) + + def op_backward( + self, + ctx: OperationContext, + grad_output: torch.Tensor, + ) -> tuple[torch.Tensor, tuple[()]]: + + # Saved tensors from forward pass + (x,) = ctx.saved_tensors + + # Tensor attributes + dtype = x.dtype + device = x.device + + # Check grad output tensor + dy = grad_output + if isinstance(dy, QuantizedTensor): + dy = dy.dequantize() + if not devices_match(dy.device, device) or dy.dtype != dtype: + dy = dy.to(device=device, dtype=dtype) + if not dy.is_contiguous(): + dy = dy.contiguous() + + # Check if FP8 is enabled + with_fp8_grad_input = False + grad_input_fp8_meta = None + grad_input_dtype = TE_DType[dtype] + grad_input_fp8_scale_inv = None + if ( + ctx.fp8_enabled + and ctx.prev_op is not None + and ctx.prev_op.num_fp8_scales("grad_output") > 0 + ): + with_fp8_grad_input = True + fp8_meta = ctx.prev_op.get_fp8_meta("grad_output") + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=False) + grad_input_fp8_meta = fp8_meta[fp8_meta_key] + grad_input_dtype = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=False) + grad_input_fp8_scale_inv = torch.empty([1], dtype=torch.float32, device=device) + + # Launch kernel + if with_fp8_grad_input: + # Fused with FP8 cast-transpose + input_dims = x.size() + flat_input_dims = [x.numel() // input_dims[-1], input_dims[-1]] + flat_output_dims = [flat_input_dims[0], flat_input_dims[1] // 2] + dx = torch.empty(input_dims, dtype=torch.uint8, device=device) + dx_t = torch.empty( + (flat_input_dims[1], flat_input_dims[0]), + dtype=torch.uint8, + device=device, + ) + fp8_dswiglu_cast_transpose_fused( + dy.reshape(flat_output_dims), + x.reshape(flat_input_dims), + grad_input=dx.reshape(flat_input_dims), + grad_input_transpose=dx_t, + otype=grad_input_dtype, + fp8_meta=grad_input_fp8_meta, + fp8_meta_index=0, + scale_inv=grad_input_fp8_scale_inv, + ) + dx = Float8Tensor( + data=dx, + fp8_meta=grad_input_fp8_meta, + fp8_meta_forward=True, + fp8_meta_index=0, + fp8_dtype=grad_input_dtype, + fp8_scale_inv=grad_input_fp8_scale_inv, + dtype=dtype, + ) + dx._transpose = dx_t + dx._transpose_invalid = False + else: + # Standard impl + dx = self._activation_backward_impl(dy, x, TE_DType[dtype]) + if dx.size() != x.size(): + dx = dx.reshape(x.size()) + + # 
Note: This fails if op is preceeded by an identity op like Quantize(forward=False) + # # Clear input tensor if possible + # if ctx.prev_op is not None: + # clear_tensor_data(x) + + return dx, () diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index 859b1ba1d7..ad86861114 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -83,6 +83,10 @@ class BasicLinear(BasicOperation): autograd. The weight's `main_grad` must be set externally and there is no guarantee that `grad` will be set or be meaningful. + userbuffers_options, dict, optional + Options for overlapping tensor-parallel communication with + compute using Userbuffers. This feature is highly + experimental. """ @@ -98,6 +102,7 @@ def __init__( sequence_parallel: bool = False, rng_state_tracker_function: Optional[Callable[[], CudaRNGStatesTracker]] = None, accumulate_into_main_grad: bool = False, + userbuffers_options: Optional[dict[str, Any]] = None, ) -> None: super().__init__() @@ -143,7 +148,7 @@ def __init__( ) # Whether weight tensor is natively in FP8 - self._with_fp8_parameters = FP8GlobalStateManager.with_fp8_parameters() + self._with_fp8_parameters: bool = FP8GlobalStateManager.with_fp8_parameters() if self._with_fp8_parameters: self._fp8_metas = self._make_fp8_metas() @@ -163,7 +168,10 @@ def __init__( self.reset_parameters() # Whether to accumulate weight gradient into main_grad - self._accumulate_into_main_grad = accumulate_into_main_grad + self._accumulate_into_main_grad: bool = accumulate_into_main_grad + + # Userbuffers options + self._userbuffers_options: Optional[dict[str, Any]] = userbuffers_options @classmethod def _canonicalize_tensor_parallelism( @@ -308,8 +316,8 @@ def reset_parameters(self) -> None: weight = torch.nn.Parameter(weight) self.weight = weight - def pre_forward(self) -> None: - super().pre_forward() + def pre_forward(self, *args, **kwargs) -> None: + super().pre_forward(*args, **kwargs) if self.weight.device.type == "meta": self.reset_parameters() @@ -707,7 +715,7 @@ def _functional_backward( FP8 metadata for casting loss gradient w.r.t. output tensor to FP8. Required if output grad is not already in FP8. - grad_output_fp8_meta: dict, optional + grad_input_fp8_meta: dict, optional FP8 metadata for casting loss gradient w.r.t. input tensor to FP8 diff --git a/transformer_engine/pytorch/ops/basic/bias.py b/transformer_engine/pytorch/ops/basic/bias.py index 44a97b3b2d..eac1865566 100644 --- a/transformer_engine/pytorch/ops/basic/bias.py +++ b/transformer_engine/pytorch/ops/basic/bias.py @@ -111,8 +111,8 @@ def reset_parameters(self) -> None: bias = torch.nn.Parameter(bias) self.bias = bias - def pre_forward(self) -> None: - super().pre_forward() + def pre_forward(self, *args, **kwargs) -> None: + super().pre_forward(*args, **kwargs) if self.bias.device.type == "meta": self.reset_parameters() diff --git a/transformer_engine/pytorch/ops/basic/layer_norm.py b/transformer_engine/pytorch/ops/basic/layer_norm.py new file mode 100644 index 0000000000..710f838581 --- /dev/null +++ b/transformer_engine/pytorch/ops/basic/layer_norm.py @@ -0,0 +1,326 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +"""Fusable operation for Layer Normalization.""" + +from __future__ import annotations +from collections.abc import Iterable +import math +import os +from typing import Optional + +import torch + +from transformer_engine_torch import layernorm_bwd, layernorm_fwd +from ...cpp_extensions import ( + layernorm_fwd_fp8, + layernorm_fwd_fp8_inf, + layernorm_fwd_inf, +) +from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype +from ...tensor import Float8Tensor, QuantizedTensor +from ...utils import ( + canonicalize_device, + canonicalize_dtype, + clear_tensor_data, + devices_match, +) +from ..op import BasicOperation, OperationContext +from .._common import maybe_autocast_dtype, reshape + + +class LayerNorm(BasicOperation): + r"""Layer Normalization + + Applies Layer Normalization over a mini-batch of inputs as described in + the paper `Layer Normalization `__ + + .. math:: + y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * \gamma + \beta + + :math:`\gamma` and :math:`\beta` are learnable affine transform + parameters that match the inner-most dimensions of the input + tensor. + + Parameters + ---------- + normalized_shape: int or iterable of int + Inner dimensions of input tensor + eps : float, default = 1e-5 + A value added to the denominator of layer normalization for + numerical stability + device: torch.device, default = default CUDA device + Tensor device + dtype: torch.dtype, default = default dtype + Tensor datatype + zero_centered_gamma : bool, default = 'False' + If `True`, the :math:`\gamma` parameter is initialized to zero + and the calculation changes to + + .. math:: + y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta + + sm_margin: int or dict, default = 0 + Number of SMs to exclude when launching CUDA kernels. This + helps overlap with other kernels, e.g. communication kernels. + For more fine-grained control, provide a dict with the SM + margin at each compute stage ("forward", "backward", + "inference"). 
+ + """ + + def __init__( + self, + normalized_shape: Iterable[int] | int, + *, + eps: float = 1e-5, + device: Optional[torch.device | str] = None, + dtype: Optional[torch.dtype] = None, + zero_centered_gamma: bool = False, + sm_margin: int | dict[str, int] = 0, + ) -> None: + super().__init__() + self.eps: float = eps + self.zero_centered_gamma: bool = zero_centered_gamma + + # Parameter shape + if not isinstance(normalized_shape, Iterable): + normalized_shape = (normalized_shape,) + else: + normalized_shape = tuple(normalized_shape) + + # Parameter device + defer_param_init = False + device = canonicalize_device(device) + if device.type == "meta": + defer_param_init = True + + # Initialize parameters if needed + dtype = canonicalize_dtype(dtype) + weight = torch.empty( + normalized_shape, + device=device, + dtype=dtype, + ) + bias = torch.empty( + normalized_shape, + device=device, + dtype=dtype, + ) + weight = torch.nn.Parameter(weight) + bias = torch.nn.Parameter(bias) + self.weight: torch.nn.Parameter + self.bias: torch.nn.Parameter + self.register_parameter("weight", weight) + self.register_parameter("bias", bias) + if not defer_param_init: + self.reset_parameters() + + # Number of SMs to exclude when launching CUDA kernels + self._sm_margins: dict[str, int] + if isinstance(sm_margin, dict): + + def getenv(name: str) -> int: + return int(os.getenv(name, "0")) + + self._sm_margins = { + "forward": sm_margin.get("forward", getenv("NVTE_FWD_LAYERNORM_SM_MARGIN")), + "backward": sm_margin.get("backward", getenv("NVTE_BWD_LAYERNORM_SM_MARGIN")), + "inference": sm_margin.get("inference", getenv("NVTE_INF_LAYERNORM_SM_MARGIN")), + } + else: + + def getenv(name: str) -> int: + return int(os.getenv(name, str(sm_margin))) + + self._sm_margins = { + "forward": getenv("NVTE_FWD_LAYERNORM_SM_MARGIN"), + "backward": getenv("NVTE_BWD_LAYERNORM_SM_MARGIN"), + "inference": getenv("NVTE_INF_LAYERNORM_SM_MARGIN"), + } + + def reset_parameters(self) -> None: + """Initialize parameter buffers and values""" + + # Parameter device + weight = self.weight + bias = self.bias + device = weight.device + if device.type == "meta": + device = canonicalize_device(None) + + # Initialize param buffers + if not devices_match(weight.device, device): + weight = torch.empty_like(weight, device=device) + if not devices_match(bias.device, device): + bias = torch.empty_like(bias, device=device) + + # Initialize values + if self.zero_centered_gamma: + torch.nn.init.zeros_(weight) + else: + torch.nn.init.ones_(weight) + torch.nn.init.zeros_(bias) + + # Save updated parameter + if not isinstance(weight, torch.nn.Parameter): + weight = torch.nn.Parameter(weight) + if not isinstance(bias, torch.nn.Parameter): + bias = torch.nn.Parameter(bias) + self.weight = weight + self.bias = bias + + def pre_forward(self, *args, **kwargs) -> None: + super().pre_forward(*args, **kwargs) + if self.weight.device.type == "meta" or self.bias.device.type == "meta": + self.reset_parameters() + + def op_forward( + self, + ctx: OperationContext, + input_: torch.Tensor, + prev_op: Optional[BasicOperation] = None, + next_op: Optional[BasicOperation] = None, + ) -> torch.Tensor: + + # Check tensor dims + weight = self.weight + weight_dims = tuple(weight.size()) + input_dims = tuple(input_.size()) + if len(input_dims) < len(weight_dims) or input_dims[-len(weight_dims) :] != weight_dims: + raise ValueError( + f"Input tensor (shape={input_dims}) " + f"and weight tensor (shape={weight_dims}) are not compatible" + ) + + # Check input tensors + inner_dim = 
math.prod(weight_dims) + device = weight.device + if device.type != "cuda": + device = canonicalize_device(None) + dtype = maybe_autocast_dtype(default_dtype=weight.dtype) + x = reshape(input_, (-1, inner_dim), device=device, dtype=dtype) + w = reshape(self.weight, (inner_dim,), device=device, dtype=dtype) + b = reshape(self.bias, (inner_dim,), device=device, dtype=dtype) + if isinstance(x, QuantizedTensor): + x = x.dequantize() + if isinstance(w, QuantizedTensor): + w = w.dequantize() + if isinstance(b, QuantizedTensor): + b = b.dequantize() + + # Check if backward pass is needed + requires_grad = ctx.requires_grad + + # Check if FP8 is enabled + with_fp8_output = ( + FP8GlobalStateManager.is_fp8_enabled() + and next_op is not None + and next_op.num_fp8_scales("input") > 0 + ) + output_fp8_meta = None + if with_fp8_output: + output_fp8_meta = next_op.get_fp8_meta("input") + + # Compute layer norm + y = None + means = None + rstdevs = None + sm_margin = self._sm_margins["forward" if requires_grad else "inference"] + if with_fp8_output: + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True) + fp8_dtype = get_fp8_te_dtype(output_fp8_meta["recipe"], fprop_tensor=True) + args = ( + x, + w, + b, + self.eps, + output_fp8_meta[fp8_meta_key], + 0, # fp8_meta_index + fp8_dtype, + sm_margin, + self.zero_centered_gamma, + ) + if requires_grad: + data, means, rstdevs = layernorm_fwd_fp8(*args) + else: + data = layernorm_fwd_fp8_inf(*args) + y = Float8Tensor( + data=data, + fp8_meta=output_fp8_meta, + fp8_meta_forward=True, + fp8_meta_index=0, + fp8_dtype=fp8_dtype, + dtype=dtype, + ) + else: + args = ( + x, + w, + b, + self.eps, + sm_margin, + self.zero_centered_gamma, + ) + if requires_grad: + y, means, rstdevs = layernorm_fwd(*args) + else: + y = layernorm_fwd_inf(*args) + + # Save state for backward pass + if requires_grad: + ctx.save_for_backward(x, means, rstdevs) + ctx.device = device + ctx.dtype = dtype + ctx.has_prev_op = prev_op is not None + + # Reshape output tensor + out = reshape(y, input_dims) + return out + + def op_backward( + self, + ctx: OperationContext, + grad_output: torch.Tensor, + ) -> tuple[torch.Tensor, tuple[()]]: + + # Saved tensors from forward pass + x, means, rstdevs = ctx.saved_tensors + + # Tensor dims + weight_dims = self.weight.size() + inner_dim = math.prod(weight_dims) + + # Check input tensors + device = ctx.device + dtype = ctx.dtype + dy = reshape(grad_output, x.size(), device=device, dtype=dtype) + w = reshape(self.weight, (inner_dim,), device=device, dtype=dtype) + if isinstance(w, QuantizedTensor): + w = w.dequantize() + if isinstance(dy, QuantizedTensor): + dy = dy.dequantize() + + # Compute layer norm backward pass + dx, dw, db = layernorm_bwd( + dy, + x, + means, + rstdevs, + w, + self._sm_margins["backward"], + self.zero_centered_gamma, + ) + + # Clear saved tensors if possible + if ctx.has_prev_op: + clear_tensor_data(x) + clear_tensor_data(means) + clear_tensor_data(rstdevs) + + # Reshape results + grad_input = reshape(dx, grad_output.size()) + grad_weight = reshape(dw, weight_dims) + grad_bias = reshape(db, weight_dims) + return grad_input, (grad_weight, grad_bias) diff --git a/transformer_engine/pytorch/ops/basic/quantize.py b/transformer_engine/pytorch/ops/basic/quantize.py new file mode 100644 index 0000000000..313b6e5583 --- /dev/null +++ b/transformer_engine/pytorch/ops/basic/quantize.py @@ -0,0 +1,93 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
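A usage note on the FP8 fast path in the layer-norm forward above (`with_fp8_output`): the norm writes FP8 output directly only when the next operation advertises FP8 input scales, which is what happens when a linear op follows it in a fusible-ops pipeline. A hedged sketch, assuming `Sequential`, `LayerNorm`, and `BasicLinear` are all exported from `transformer_engine.pytorch.ops` and an FP8-capable GPU is available:

    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.pytorch.ops import Sequential, LayerNorm, BasicLinear

    model = Sequential(
        LayerNorm(1024),
        BasicLinear(1024, 4096),
    )
    x = torch.randn(32, 1024, device="cuda", requires_grad=True)

    with te.fp8_autocast(enabled=True):
        # Inside fp8_autocast the norm can emit FP8 directly, using the linear
        # op's "input" scale; outside the context it produces a regular tensor.
        y = model(x)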
+ +"""Fusible operation for quantization.""" + +from __future__ import annotations +from typing import Optional + +import torch + +from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype +from ...tensor import Float8Tensor, QuantizedTensor +from ..op import BasicOperation, OperationContext + + +class Quantize(BasicOperation): + """Quantize tensor data + + Uses FP8 recipe from `fp8_autocast` context. When called outside + of an `fp8_autocast` context, this is an identity operation. + + Parameters + ---------- + forward: bool, default = `True` + Perform quantization in forward pass + backward: bool, default = `False` + Perform quantization in backward pass + + """ + + def __init__( + self, + forward: bool = True, + backward: bool = False, + ) -> None: + super().__init__() + self._quantize_forward = forward + self._quantize_backward = backward + + def num_fp8_scales(self, mode: str) -> int: + if mode == "input" and self._quantize_forward: + return 1 + if mode == "grad_output" and self._quantize_backward: + return 1 + return 0 + + def op_forward( + self, + ctx: OperationContext, + input_: torch.Tensor, + prev_op: Optional[BasicOperation] = None, + next_op: Optional[BasicOperation] = None, + ) -> torch.Tensor: + + # Check if FP8 is enabled + fp8_enabled = FP8GlobalStateManager.is_fp8_enabled() + quantize_forward = fp8_enabled and self._quantize_forward + quantize_backward = fp8_enabled and self._quantize_backward + + # Quantize if needed + out = input_ + if quantize_forward and not isinstance(out, QuantizedTensor): + fp8_meta = self.get_fp8_meta("input") + fp8_dtype = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) + out = Float8Tensor.to_float8( + out, + fp8_meta=fp8_meta, + fp8_meta_forward=True, + fp8_meta_index=0, + fp8_dtype=fp8_dtype, + ) + + ctx.quantize_backward = quantize_backward + return out + + def op_backward( + self, + ctx: OperationContext, + grad_output: torch.Tensor, + ) -> tuple[torch.Tensor, tuple[()]]: + grad_input = grad_output + if ctx.quantize_backward and not isinstance(grad_input, QuantizedTensor): + fp8_meta = self.get_fp8_meta("grad_output") + fp8_dtype = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=False) + grad_input = Float8Tensor.to_float8( + grad_input, + fp8_meta=fp8_meta, + fp8_meta_forward=False, + fp8_meta_index=0, + fp8_dtype=fp8_dtype, + ) + return grad_input, () diff --git a/transformer_engine/pytorch/ops/basic/rmsnorm.py b/transformer_engine/pytorch/ops/basic/rmsnorm.py new file mode 100644 index 0000000000..84f05ce713 --- /dev/null +++ b/transformer_engine/pytorch/ops/basic/rmsnorm.py @@ -0,0 +1,300 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +"""Fusable operation for RMSNorm.""" + +from __future__ import annotations +from collections.abc import Iterable +import math +import os +from typing import Optional + +import torch + +from transformer_engine_torch import rmsnorm_bwd, rmsnorm_fwd +from ...cpp_extensions import ( + rmsnorm_fwd_fp8, + rmsnorm_fwd_fp8_inf, + rmsnorm_fwd_inf, +) +from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype +from ...tensor import Float8Tensor, QuantizedTensor +from ...utils import ( + canonicalize_device, + canonicalize_dtype, + clear_tensor_data, + devices_match, +) +from ..op import BasicOperation, OperationContext +from .._common import maybe_autocast_dtype, reshape + + +class RMSNorm(BasicOperation): + r"""Root Mean Square Layer Normalization + + Applies Root Mean Square Layer Normalization over a mini-batch of + inputs as described in the paper + `Root Mean Square Layer Normalization `__ + + .. math:: + y = \frac{x}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * \gamma + + :math:`\gamma` is a learnable affine transform parameter that + matches the inner-most dimensions of the input tensor. + + Parameters + ---------- + normalized_shape: int or iterable of int + Inner dimensions of input tensor + eps : float, default = 1e-5 + A value added to the denominator for numerical stability + device: torch.device, default = default CUDA device + Tensor device + dtype: torch.dtype, default = default dtype + Tensor datatype + zero_centered_gamma : bool, default = 'False' + If `True`, the :math:`\gamma` parameter is initialized to zero + and the calculation changes to + + .. math:: + y = \frac{x}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + + sm_margin: int, default = 0 + Number of SMs to exclude when launching CUDA kernels. This + helps overlap with other kernels, e.g. communication kernels. + For more fine-grained control, provide a dict with the SM + margin at each compute stage ("forward", "backward", + "inference"). 
+ + """ + + def __init__( + self, + normalized_shape: Iterable[int] | int, + *, + eps: float = 1e-5, + device: Optional[torch.device | str] = None, + dtype: Optional[torch.dtype] = None, + zero_centered_gamma: bool = False, + sm_margin: int = 0, + ) -> None: + super().__init__() + self.eps: float = eps + self.zero_centered_gamma: bool = zero_centered_gamma + + # Parameter shape + if not isinstance(normalized_shape, Iterable): + normalized_shape = (normalized_shape,) + else: + normalized_shape = tuple(normalized_shape) + + # Parameter device + defer_param_init = False + device = canonicalize_device(device) + if device.type == "meta": + defer_param_init = True + + # Initialize parameters if needed + weight = torch.empty( + normalized_shape, + device=device, + dtype=canonicalize_dtype(dtype), + ) + weight = torch.nn.Parameter(weight) + self.weight: torch.nn.Parameter + self.register_parameter("weight", weight) + if not defer_param_init: + self.reset_parameters() + + # Number of SMs to exclude when launching CUDA kernels + self._sm_margins: dict[str, int] + if isinstance(sm_margin, dict): + + def getenv(name: str) -> int: + return int(os.getenv(name, "0")) + + self._sm_margins = { + "forward": sm_margin.get("forward", getenv("NVTE_FWD_LAYERNORM_SM_MARGIN")), + "backward": sm_margin.get("backward", getenv("NVTE_BWD_LAYERNORM_SM_MARGIN")), + "inference": sm_margin.get("inference", getenv("NVTE_INF_LAYERNORM_SM_MARGIN")), + } + else: + + def getenv(name: str) -> int: + return int(os.getenv(name, str(sm_margin))) + + self._sm_margins = { + "forward": getenv("NVTE_FWD_LAYERNORM_SM_MARGIN"), + "backward": getenv("NVTE_BWD_LAYERNORM_SM_MARGIN"), + "inference": getenv("NVTE_INF_LAYERNORM_SM_MARGIN"), + } + + def reset_parameters(self) -> None: + """Initialize parameter buffers and values""" + + # Parameter device + weight = self.weight + device = weight.device + if device.type == "meta": + device = canonicalize_device(None) + + # Initialize param buffers + if not devices_match(weight.device, device): + weight = torch.empty_like(weight, device=device) + + # Initialize values + if self.zero_centered_gamma: + torch.nn.init.zeros_(weight) + else: + torch.nn.init.ones_(weight) + + # Save updated parameter + if not isinstance(weight, torch.nn.Parameter): + weight = torch.nn.Parameter(weight) + self.weight = weight + + def pre_forward(self, *args, **kwargs) -> None: + super().pre_forward(*args, **kwargs) + if self.weight.device.type == "meta": + self.reset_parameters() + + def op_forward( + self, + ctx: OperationContext, + input_: torch.Tensor, + prev_op: Optional[BasicOperation] = None, + next_op: Optional[BasicOperation] = None, + ) -> torch.Tensor: + + # Check tensor dims + weight = self.weight + weight_dims = tuple(weight.size()) + input_dims = tuple(input_.size()) + if len(input_dims) < len(weight_dims) or input_dims[-len(weight_dims) :] != weight_dims: + raise ValueError( + f"Input tensor (shape={input_dims}) " + f"and weight tensor (shape={weight_dims}) are not compatible" + ) + + # Check input tensors + inner_dim = math.prod(weight_dims) + device = weight.device + if device.type != "cuda": + device = canonicalize_device(None) + dtype = maybe_autocast_dtype(default_dtype=weight.dtype) + x = reshape(input_, (-1, inner_dim), device=device, dtype=dtype) + w = reshape(self.weight, (inner_dim,), device=device, dtype=dtype) + if isinstance(x, QuantizedTensor): + x = x.dequantize() + if isinstance(w, QuantizedTensor): + w = w.dequantize() + + # Check if backward pass is needed + requires_grad = 
ctx.requires_grad + + # Check if FP8 is enabled + with_fp8_output = ( + FP8GlobalStateManager.is_fp8_enabled() + and next_op is not None + and next_op.num_fp8_scales("input") > 0 + ) + output_fp8_meta = None + if with_fp8_output: + output_fp8_meta = next_op.get_fp8_meta("input") + + # Compute RMSNorm + y = None + rstdevs = None + sm_margin = self._sm_margins["forward" if requires_grad else "inference"] + if with_fp8_output: + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True) + fp8_dtype = get_fp8_te_dtype(output_fp8_meta["recipe"], fprop_tensor=True) + args = ( + x, + w, + self.eps, + output_fp8_meta[fp8_meta_key], + 0, # fp8_meta_index + fp8_dtype, + sm_margin, + self.zero_centered_gamma, + ) + if requires_grad: + data, rstdevs = rmsnorm_fwd_fp8(*args) + else: + data = rmsnorm_fwd_fp8_inf(*args) + y = Float8Tensor( + data=data, + fp8_meta=output_fp8_meta, + fp8_meta_forward=True, + fp8_meta_index=0, + fp8_dtype=fp8_dtype, + dtype=dtype, + ) + else: + args = ( + x, + w, + self.eps, + sm_margin, + self.zero_centered_gamma, + ) + if requires_grad: + y, rstdevs = rmsnorm_fwd(*args) + else: + y = rmsnorm_fwd_inf(*args) + + # Save state for backward pass + if requires_grad: + ctx.save_for_backward(x, rstdevs) + ctx.device = device + ctx.dtype = dtype + ctx.has_prev_op = prev_op is not None + + # Reshape output tensor + out = reshape(y, input_dims) + return out + + def op_backward( + self, + ctx: OperationContext, + grad_output: torch.Tensor, + ) -> tuple[torch.Tensor, tuple[()]]: + + # Saved tensors from forward pass + x, rstdevs = ctx.saved_tensors + + # Tensor dims + weight_dims = self.weight.size() + inner_dim = math.prod(weight_dims) + + # Check input tensors + device = ctx.device + dtype = ctx.dtype + dy = reshape(grad_output, x.size(), device=device, dtype=dtype) + w = reshape(self.weight, (inner_dim,), device=device, dtype=dtype) + if isinstance(w, QuantizedTensor): + w = w.dequantize() + if isinstance(dy, QuantizedTensor): + dy = dy.dequantize() + + # Compute RMSNorm backward pass + dx, dw = rmsnorm_bwd( + dy, + x, + rstdevs, + w, + self._sm_margins["backward"], + self.zero_centered_gamma, + ) + + # Clear saved tensors if possible + if ctx.has_prev_op: + clear_tensor_data(x) + clear_tensor_data(rstdevs) + + # Reshape results + grad_input = reshape(dx, grad_output.size()) + grad_weight = reshape(dw, weight_dims) + return grad_input, (grad_weight,) diff --git a/transformer_engine/pytorch/ops/fused/__init__.py b/transformer_engine/pytorch/ops/fused/__init__.py index bd832254d8..08b9f06123 100644 --- a/transformer_engine/pytorch/ops/fused/__init__.py +++ b/transformer_engine/pytorch/ops/fused/__init__.py @@ -16,3 +16,11 @@ ForwardLinearBiasAdd, fuse_forward_linear_bias_add, ) +from .userbuffers_backward_linear import ( + UserbuffersBackwardLinear, + fuse_userbuffers_backward_linear, +) +from .userbuffers_forward_linear import ( + UserbuffersForwardLinear, + fuse_userbuffers_forward_linear, +) diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py new file mode 100644 index 0000000000..907cff1c81 --- /dev/null +++ b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py @@ -0,0 +1,781 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
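The three basic ops added above (LayerNorm, RMSNorm, Quantize) all follow the same `BasicOperation` contract: `op_forward(ctx, input_, prev_op, next_op)` returns the output, and `op_backward(ctx, grad_output)` returns the input gradient plus a tuple of parameter gradients in parameter order. A minimal skeleton of a hypothetical custom op, for orientation only; the `Scale` op is illustrative and not part of this patch.

    import torch
    from transformer_engine.pytorch.ops.op import BasicOperation, OperationContext

    class Scale(BasicOperation):
        """Hypothetical op that multiplies its input by a constant."""

        def __init__(self, alpha: float) -> None:
            super().__init__()
            self.alpha = alpha

        def op_forward(self, ctx: OperationContext, input_, prev_op=None, next_op=None):
            return input_ * self.alpha

        def op_backward(self, ctx: OperationContext, grad_output):
            # (grad w.r.t. input, tuple of grads w.r.t. this op's parameters)
            return grad_output * self.alpha, ()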
+ +"""Linear layer backward with Userbuffers communication.""" + +from __future__ import annotations +from collections.abc import Iterable +from typing import Any, Optional +import warnings + +import torch + +from transformer_engine_torch import CommOverlapAlgo +from ...cpp_extensions import ( + fp8_cast_transpose_bgrad_fused, + fp8_gemm, + gemm, +) +from ...distributed import get_distributed_world_size +from ...float8_tensor import Float8Tensor +from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype +from ...module.base import get_ub, get_workspace +from ...utils import canonicalize_device, canonicalize_dtype, clear_tensor_data +from ..basic import BasicLinear, Bias, ReduceScatter +from ..op import FusedOperation, FusibleOperation, OperationContext +from .._common import ( + convert_tensor, + get_fp8_meta_from_fp8_tensor, + is_float8_tensor, + reshape, +) + + +class UserbuffersBackwardLinear(FusedOperation): + """Linear backward implementation using Userbuffers + + This operation is equivalent to a linear operation's backward + pass, but it uses Userbuffers to overlap tensor-parallel + communication with compute. + + """ + + def __init__( + self, + *, + linear: BasicLinear, + bias: Optional[Bias], + reduce_scatter: Optional[ReduceScatter], + ) -> None: + + # Basic operations that comprise this fused operation + op_idxs = {"linear": None, "bias": None, "reduce_scatter": None} + ops = [] + if reduce_scatter is not None: + op_idxs["reduce_scatter"] = len(ops) + ops.append(reduce_scatter) + if bias is not None: + op_idxs["bias"] = len(ops) + ops.append(bias) + op_idxs["linear"] = len(ops) + ops.append(linear) + + # Initialize base class + super().__init__(ops) + + # Index of each basic operations + self._op_idxs: dict[str, Optional[int]] = op_idxs + + # Tensor parallelism configuration + self.tensor_parallel_mode: Optional[str] + self.tensor_parallel_group: Optional[torch.distributed.ProcessGroup] + self.tensor_parallel_size: int + self.sequence_parallel: bool + if reduce_scatter is None: + self.tensor_parallel_mode = linear.tensor_parallel_mode + self.tensor_parallel_group = linear.tensor_parallel_group + self.tensor_parallel_size = linear.tensor_parallel_size + self.sequence_parallel = linear.sequence_parallel + else: + self.tensor_parallel_mode = "row" + self.tensor_parallel_group = reduce_scatter.process_group + self.tensor_parallel_size = reduce_scatter.process_group_size + self.sequence_parallel = True + + @staticmethod + def _functional_backward( + grad_output: torch.Tensor, + input: Optional[torch.Tensor], # pylint: disable=redefined-builtin + weight: Optional[torch.Tensor], + input_dims: Iterable[int], + weight_dims: Iterable[int], + *, + weight_requires_grad: bool = True, + bias_requires_grad: bool = False, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + grad_weight: Optional[torch.Tensor] = None, + accumulate_into_grad_weight: bool = False, + tensor_parallel_mode: Optional[str] = None, + tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + tensor_parallel_size: Optional[int] = None, + sequence_parallel: bool = False, + with_fp8_compute: bool = False, + input_fp8_meta: Optional[dict[str, Any]] = None, + weight_fp8_meta: Optional[dict[str, Any]] = None, + grad_output_fp8_meta: Optional[dict[str, Any]] = None, + grad_input_fp8_meta: Optional[dict[str, Any]] = None, + ub_comm_name: str, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], dict]: + """Functional API for backward pass + + Parameters + ---------- + grad_output: 
torch.Tensor + Loss gradient w.r.t. output tensor + input: torch.Tensor, optional + Input tensor. Required to compute loss gradient w.r.t. + weight. + weight: torch.Tensor, optional + Weight tensor. Required to compute loss gradient w.r.t. + input. + input_dims: iterable of int + Input tensor dimensions + weight_dims: iterable of int + Weight tensor dimensions + weight_requires_grad: bool + Whether to compute loss gradient w.r.t. weight tensor + bias_requires_grad: bool + Whether to compute loss gradient w.r.t. bias tensor + device: torch.device, default = default CUDA device + Tensor device + dtype: torch.dtype, default = default dtype + Tensor datatype + grad_weight: torch.Tensor, optional + Loss gradient w.r.t. weight tensor + accumulate_into_grad_weight: bool, default = `False` + Add result to weight grad instead of overwriting + tensor_parallel_mode: {`None`, "column", "row"}, default = `None` + Mode for tensor parallelism + tensor_parallel_group: torch.distributed.ProcessGroup, default = world group + Process group for tensor parallelism + sequence_parallel: bool, default = `False` + Whether to apply sequence parallelism together with tensor + parallelism, i.e. distributing input or output tensors + along outer dimension (sequence or batch dim) when not + distributing along inner dimension (embedding dim) + with_fp8_compute: bool, default = `False` + Whether to perform compute in FP8 + input_fp8_meta: dict, optional + FP8 metadata for casting input tensor to FP8. Required for + FP8 compute if input is not already in FP8. + weight_fp8_meta: dict, optional + FP8 metadata for casting weight tensor to FP8. Required for + FP8 compute if weight is not already in FP8. + grad_output_fp8_meta: dict, optional + FP8 metadata for casting loss gradient w.r.t. output + tensor to FP8. Required if output grad is not already in + FP8. + grad_input_fp8_meta: dict, optional + FP8 metadata for casting loss gradient w.r.t. input + tensor to FP8 + ub_comm_name: str + Layer type (e.g. "qkv", "proj", "fc1", "fc2"). This is + used to access the corresponding Userbuffers communicators + (e.g. "qkv_dgrad", "qkv_wgrad"). + + Returns + ------- + torch.Tensor + Loss gradient w.r.t. input tensor + torch.Tensor + Loss gradient w.r.t. weight tensor + dict + Extra output tensors. "grad_bias" is loss gradient w.r.t. + the bias tensor. 
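The `ub_comm_name` convention described above maps one logical layer name onto several Userbuffers communicators; the snippet below just spells out the keys this patch derives from a single name via `get_ub`. The communicators themselves must have been registered beforehand through Transformer Engine's Userbuffers initialization; that setup step is assumed here and not shown.

    # Keys derived from one comm_name in this patch ("proj" is only an example):
    comm_name = "proj"
    ub_keys = {
        "forward GEMM overlap": f"{comm_name}_fprop",
        "dgrad GEMM overlap": f"{comm_name}_dgrad",
        "wgrad GEMM bulk overlap": f"{comm_name}_wgrad",
    }
    for role, key in ub_keys.items():
        print(f"{role}: get_ub({key!r})")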
+ + """ + + # Configuration-specific outputs + extra_outputs = {} + + # Check device + if device is None: + device = weight.device + device = canonicalize_device(device) + if device.type != "cuda": + raise ValueError(f"Only CUDA devices are supported (got {device})") + + # Check datatype + if dtype is None: + dtype = weight.dtype + dtype = canonicalize_dtype(dtype) + if dtype not in (torch.float32, torch.float16, torch.bfloat16): + raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})") + + # Input tensor dims + output_dims = tuple(grad_output.size()) + input_dims = tuple(input_dims) + weight_dims = tuple(weight_dims) + if len(weight_dims) != 2: + raise ValueError(f"Weight tensor is not 2D (shape={weight_dims})") + if len(input_dims) == 0 or weight_dims[1] != input_dims[-1]: + raise ValueError( + f"Input tensor (shape={input_dims}) " + f"and weight tensor (shape={weight_dims}) " + "are not compatible" + ) + if weight_dims[0] != output_dims[-1]: + raise ValueError( + f"Grad output tensor (shape={output_dims}) " + f"and weight tensor (shape={weight_dims}) " + "are not compatible" + ) + + # Check tensor parallel group + if tensor_parallel_size is None: + tensor_parallel_size = get_distributed_world_size(tensor_parallel_group) + if tensor_parallel_size == 1: + tensor_parallel_mode = None + if tensor_parallel_mode not in ("column", "row"): + raise RuntimeError( + "Invalid configuration for Userbuffers " + f"({tensor_parallel_size=}, {tensor_parallel_mode=})" + ) + if not sequence_parallel: + raise RuntimeError(f"Invalid configuration for Userbuffers ({sequence_parallel=})") + + # Check if FP8 is enabled + if with_fp8_compute: + if grad_output_fp8_meta is None and not is_float8_tensor(grad_output): + raise ValueError("No FP8 metadata was provided for casting output gradient to FP8") + else: + input_fp8_meta = None + weight_fp8_meta = None + grad_output_fp8_meta = None + grad_input_fp8_meta = None + with_fp8_grad_input = ( + with_fp8_compute + and tensor_parallel_mode != "column" + and grad_input_fp8_meta is not None + ) + + # Get Userbuffers communicators and algorithms + # Note: communication patterns are (1) overlap dy all-gather + # with dgrad GEMM, (2) overlap x all-gather with dgrad GEMM + # and dx reduce-scatter with wgrad GEMM, (3) overlap dx + # reduce-scatter with dgrad GEMM. 
+ with_ub_all_gather_dy = False + with_ub_reduce_scatter_dx = False + with_ub_all_gather_x = False + ub_comm_dy = None + ub_comm_dx = None + ub_comm_x = None + ub_algo_dy = None + ub_algo_dx = None + ub_algo_x = None + if tensor_parallel_mode == "row": + with_ub_all_gather_dy = True + ub_comm_dy = get_ub(ub_comm_name + "_dgrad") + if with_fp8_compute and ub_comm_dy.is_atomic_gemm(): + ub_algo_dy = CommOverlapAlgo.ATOMIC_GEMM_AG_P2P + else: + ub_algo_dy = CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P + elif tensor_parallel_mode == "column": + with_ub_reduce_scatter_dx = True + if weight_requires_grad: + with_ub_all_gather_x = True + ub_comm_dx = get_ub(ub_comm_name + "_wgrad") + ub_comm_x = get_ub(ub_comm_name + "_dgrad") + ub_algo_dx = CommOverlapAlgo.BULK_OVERLAP_RS + ub_algo_x = CommOverlapAlgo.BULK_OVERLAP_AG + else: + with_ub_all_gather_x = False + ub_comm_dx = get_ub(ub_comm_name + "_dgrad") + is_atomic_gemm = with_fp8_compute and ub_comm_dx.is_atomic_gemm() + ub_algo_dx = { + (True, True): CommOverlapAlgo.ATOMIC_GEMM_RS_P2P, + (True, False): CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P, + (False, True): CommOverlapAlgo.ATOMIC_GEMM_RS, + (False, False): CommOverlapAlgo.SPLIT_PIPELINED_RS, + }[(ub_comm_dx.is_p2p_overlap(), is_atomic_gemm)] + + # Check grad output tensor + # Note: Possibly fuse cast with computing grad bias + dy_local = reshape( + grad_output, + (-1, output_dims[-1]), + device=device, + dtype=dtype, + ) + db = None + db_async = None + if bias_requires_grad and with_fp8_compute and with_ub_all_gather_dy: + # We don't have a grad bias impl that takes FP8 input. For + # cases where we cast to FP8 and all-gather, it's better + # to compute the grad bias on ungathered, non-FP8 values. + db = dy_local.sum(dim=0) + db_async = torch.distributed.all_reduce( + db, + group=tensor_parallel_group, + async_op=True, + ) + if with_fp8_compute and not is_float8_tensor(dy_local): + fp8_dtype = get_fp8_te_dtype( + grad_output_fp8_meta["recipe"], + fprop_tensor=False, + ) + if bias_requires_grad and db is None: + # Fused cast-transpose-bgrad + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=False) + fp8_scale_inv = torch.empty([1], dtype=torch.float32, device=device) + db, data, data_transpose = fp8_cast_transpose_bgrad_fused( + dy_local, + grad_output_fp8_meta[fp8_meta_key], + 0, + fp8_dtype, + scale_inv=fp8_scale_inv, + ) + if with_ub_all_gather_dy: + data = ub_comm_dy.get_ubuf_output(0).copy_(data) + dy_local = Float8Tensor( + data=data, + fp8_meta=grad_output_fp8_meta, + fp8_meta_forward=False, + fp8_meta_index=0, + fp8_dtype=fp8_dtype, + fp8_scale_inv=fp8_scale_inv, + dtype=dtype, + data_transpose=data_transpose, + ) + else: + dy_local = Float8Tensor.to_float8( + dy_local, + fp8_meta=grad_output_fp8_meta, + fp8_meta_forward=False, + fp8_meta_index=0, + fp8_dtype=fp8_dtype, + data=(ub_comm_dy.get_ubuf_output(0) if with_ub_all_gather_dy else None), + with_transpose_cache=(not with_ub_all_gather_dy), + ) + elif not with_fp8_compute and is_float8_tensor(dy_local): + if with_ub_all_gather_dy: + ub_local_buffer = ub_comm_dy.get_ubuf_output(0) + dy_local = ub_local_buffer.copy_(dy_local) + else: + dy_local = dy_local.dequantize() + + if bias_requires_grad and db is None and with_fp8_compute and with_ub_all_gather_dy: + # We don't have a fused grad bias impl that takes FP8 + # input. For cases where we cast to FP8 and all-gather, + # it's better to compute the grad bias on ungathered, + # non-FP8 values. 
+ db = dy_local.sum(dim=0) + db_async = torch.distributed.all_reduce( + db, + group=tensor_parallel_group, + async_op=True, + ) + + # Check input tensor + x_local = None + if weight_requires_grad: + x_local = reshape( + input, + (-1, input_dims[-1]), + device=device, + dtype=dtype, + ) + if with_fp8_compute and not is_float8_tensor(x_local): + fp8_dtype = get_fp8_te_dtype( + input_fp8_meta["recipe"], + fprop_tensor=True, + ) + x_local = Float8Tensor.to_float8( + x_local, + fp8_meta=input_fp8_meta, + fp8_meta_forward=True, + fp8_meta_index=0, + fp8_dtype=fp8_dtype, + data=(ub_comm_x.get_ubuf_output(0) if with_ub_all_gather_x else None), + with_transpose_cache=(not with_ub_all_gather_x), + ) + elif not with_fp8_compute and is_float8_tensor(x_local): + if with_ub_all_gather_x: + ub_local_buffer = ub_comm_x.get_ubuf_output(0) + x_local = ub_local_buffer.copy_(x_local) + else: + x_local = x_local.dequantize() + + # Check weight tensor + w = convert_tensor( + weight, + device=device, + dtype=dtype, + memory_format=torch.contiguous_format, + ) + if with_fp8_compute and not is_float8_tensor(w): + fp8_dtype = get_fp8_te_dtype( + weight_fp8_meta["recipe"], + fprop_tensor=True, + ) + w = Float8Tensor.to_float8( + w, + fp8_meta=weight_fp8_meta, + fp8_meta_forward=True, + fp8_meta_index=0, + fp8_dtype=fp8_dtype, + with_transpose_cache=True, + ) + elif not with_fp8_compute and is_float8_tensor(w): + w = w.dequantize() + + # Initialize buffers for UB all-gather if needed + dy = dy_local + x = x_local + if with_ub_all_gather_dy: + ub_local_buffer = ub_comm_dy.get_ubuf_output(0) + ub_global_buffer = ub_comm_dy.get_ubuf_output(1) + if with_fp8_compute: + dy = Float8Tensor.make_like(dy_local, data=ub_global_buffer) + if dy_local._data.data_ptr() != ub_local_buffer.data_ptr(): + ub_local_buffer.copy_(dy_local._data) + else: + dy = ub_global_buffer + if dy_local.data_ptr() != ub_local_buffer.data_ptr(): + ub_local_buffer.copy_(dy_local) + if with_ub_all_gather_x: + ub_local_buffer = ub_comm_x.get_ubuf_output(0) + ub_global_buffer = ub_comm_x.get_ubuf_output(1) + if with_fp8_compute: + x = Float8Tensor.make_like(x_local, data=ub_global_buffer) + if x_local._data.data_ptr() != ub_local_buffer.data_ptr(): + ub_local_buffer.copy_(x_local._data) + else: + x = ub_global_buffer + if x_local.data_ptr() != ub_local_buffer.data_ptr(): + ub_local_buffer.copy_(x_local) + + # Construct grad input tensor + dx = None + dx_local = None + if with_ub_reduce_scatter_dx: + # Initialize buffers for UB reduce-scatter + dx = ub_comm_dx.get_ubuf_output(1) + ub_local_buffer = ub_comm_dx.get_ubuf_output(0) + if with_ub_all_gather_x: + dx_local = ub_local_buffer + else: + dx_local = torch.empty_like(ub_local_buffer) + else: + # Allocate grad input tensor + if with_fp8_grad_input: + fp8_dtype = get_fp8_te_dtype( + grad_input_fp8_meta["recipe"], + fprop_tensor=False, + ) + data = torch.empty( + (dy.size(0), w.size(-1)), + dtype=torch.uint8, + device=device, + ) + dx = Float8Tensor( + data=data, + fp8_meta=grad_input_fp8_meta, + fp8_meta_forward=False, + fp8_meta_index=0, + fp8_dtype=fp8_dtype, + dtype=dtype, + ) + else: + dx = torch.empty( + (dy.size(0), w.size(-1)), + dtype=dtype, + device=device, + ) + dx_local = dx + + # Allocate grad input tensor + if grad_weight is None: + if accumulate_into_grad_weight: + raise ValueError( + "Attempted to accumulate into grad weight bufferwithout providing grad weight" + ) + grad_weight = torch.empty( + weight_dims, + dtype=dtype, + device=device, + memory_format=torch.contiguous_format, + ) + + # 
Perform dgrad GEMM + if with_fp8_compute: + kwargs = {"out": dx, "use_split_accumulator": True} + if with_ub_all_gather_dy: + kwargs["ub_algo"] = ub_algo_dy + kwargs["ub"] = ub_comm_dy + elif with_ub_all_gather_x: + kwargs["ub_algo"] = ub_algo_x + kwargs["ub"] = ub_comm_x + elif with_ub_reduce_scatter_dx: + kwargs["ub_algo"] = ub_algo_dx + kwargs["ub"] = ub_comm_dx + kwargs["extra_output_tensor"] = dx_local + if with_fp8_grad_input: + fp8_meta, fp8_meta_index = get_fp8_meta_from_fp8_tensor(dx) + kwargs.update( + { + "out": dx._data, + "out_index": fp8_meta_index, + "fp8_meta_tensor": fp8_meta, + "D_dtype": dx._fp8_dtype, + } + ) + fp8_gemm( + w.transpose_2d(), + w._scale_inv, + 0, + w._fp8_dtype, + dy._data, + dy._scale_inv, + 0, + dy._fp8_dtype, + dy.dtype, + get_workspace(), + **kwargs, + ) + else: + kwargs = {"grad": True, "layout": "NN", "out": dx} + if with_ub_all_gather_dy: + kwargs["ub_algo"] = ub_algo_dy + kwargs["ub"] = ub_comm_dy + elif with_ub_all_gather_x: + kwargs["ub_algo"] = ub_algo_x + kwargs["ub"] = ub_comm_x + elif with_ub_reduce_scatter_dx: + kwargs["ub_algo"] = ub_algo_dx + kwargs["ub"] = ub_comm_dx + kwargs["extra_output_tensor"] = dx_local + gemm(w, dy, dx.dtype, get_workspace(), **kwargs) + grad_input = reshape(dx_local, input_dims) + + # Perform wgrad GEMM + if not weight_requires_grad: + pass + elif with_fp8_compute: + kwargs = { + "accumulate": accumulate_into_grad_weight, + "out": grad_weight, + "use_split_accumulator": True, + } + if with_ub_reduce_scatter_dx: + kwargs["ub_algo"] = ub_algo_dx + kwargs["ub"] = ub_comm_dx + fp8_gemm( + x.transpose_2d(), + x._scale_inv, + 0, + x._fp8_dtype, + dy.transpose_2d(), + dy._scale_inv, + 0, + dy._fp8_dtype, + grad_weight.dtype, + get_workspace(), + **kwargs, + ) + else: + kwargs = { + "accumulate": accumulate_into_grad_weight, + "layout": "NT", + "grad": True, + "use_bias": bias_requires_grad, + "out": grad_weight, + } + if with_ub_reduce_scatter_dx: + kwargs["ub_algo"] = ub_algo_dx + kwargs["ub"] = ub_comm_dx + grad_weight, db, _ = gemm( + x, + dy, + grad_weight.dtype, + get_workspace(), + **kwargs, + ) + + # Compute grad bias if needed + if db_async is not None: + db_async.wait() + if bias_requires_grad: + if db is None: + db = dy.sum(dim=0) + extra_outputs["grad_bias"] = db + + return grad_input, grad_weight, extra_outputs + + def fuser_backward( + self, + basic_op_ctxs: list[OperationContext], + grad_output: torch.Tensor, + *, + basic_op_grad_extra_outputs: list[tuple[torch.Tensor, ...]], + ) -> tuple[ + torch.Tensor, + list[tuple[Optional[torch.Tensor], ...]], + list[tuple[()]], + ]: + + # Get basic operations + idx = self._op_idxs["linear"] + linear_op = self.basic_ops[idx] + linear_op_ctx = basic_op_ctxs[idx] + bias_op = None + if self._op_idxs["bias"] is not None: + idx = self._op_idxs["bias"] + bias_op = self.basic_ops[idx] + + # Saved tensors from forward pass + (x_local,) = linear_op_ctx.saved_tensors + + # wgrad fusion + accumulate_into_main_grad = linear_op._accumulate_into_main_grad + grad_weight = None + if linear_op_ctx.weight_requires_grad and accumulate_into_main_grad: + if not hasattr(linear_op.weight, "main_grad"): + raise RuntimeError( + "BasicLinear op is configured with " + "accumulate_into_main_grad=True, " + "but weight parameter does not have main_grad attribute" + ) + grad_weight = linear_op.weight.main_grad.detach() + else: + accumulate_into_main_grad = False + + # Hackily workaround Userbuffers bug with non-FP8 dgrad + # reduce-scatter overlap + weight_requires_grad = 
linear_op_ctx.weight_requires_grad + if not linear_op_ctx.with_fp8_compute and not weight_requires_grad: + warnings.warn( + "There is a correctness bug when using Userbuffers " + "to overlap a dgrad reduce-scatter with a non-FP8 dgrad GEMM. " + "Hackily working around by overlapping dgrad reduce-scatter " + "with wgrad GEMM, even though wgrad isn't needed. " + "Please contact Transformer Engine team " + "if you encounter this use-case." + ) + weight_requires_grad = True + + # Linear backward pass + retval = UserbuffersBackwardLinear._functional_backward( + grad_output=grad_output, + input=x_local, + weight=linear_op.weight, + input_dims=linear_op_ctx.input_dims, + weight_dims=linear_op.weight.size(), + weight_requires_grad=weight_requires_grad, + bias_requires_grad=(bias_op is not None), + device=linear_op.device, + dtype=linear_op_ctx.dtype, + grad_weight=grad_weight, + accumulate_into_grad_weight=accumulate_into_main_grad, + tensor_parallel_mode=self.tensor_parallel_mode, + tensor_parallel_group=self.tensor_parallel_group, + sequence_parallel=self.sequence_parallel, + with_fp8_compute=linear_op_ctx.with_fp8_compute, + weight_fp8_meta=linear_op_ctx.weight_fp8_meta, + grad_output_fp8_meta=linear_op_ctx.grad_output_fp8_meta, + grad_input_fp8_meta=linear_op_ctx.grad_input_fp8_meta, + ub_comm_name=linear_op._userbuffers_options["comm_name"], + ) + grad_input, grad_weight, extra_outputs = retval + grad_bias = None + if bias_op is not None: + grad_bias = extra_outputs["grad_bias"] + + # Clear input tensor if possible + if linear_op_ctx.has_prev_op: + clear_tensor_data(x_local) + + # Return gradients + grad_params = [() for _ in range(len(self.basic_ops))] + if accumulate_into_main_grad: + grad_weight = None + grad_params[self._op_idxs["linear"]] = (grad_weight,) + if bias_op is not None: + grad_params[self._op_idxs["bias"]] = (grad_bias,) + grad_extra_inputs = [() for _ in range(len(self.basic_ops))] + return grad_input, grad_params, grad_extra_inputs + + +def fuse_userbuffers_backward_linear( + ops: list[tuple[FusibleOperation, list[int]]], +) -> list[tuple[FusibleOperation, list[int]]]: + """Substitute linear operations with Userbuffers implementation + + Parameters + ---------- + ops: list of tuples + Forward pass operations and the indices of the corresponding + basic operations. 
+ + Returns + ------- + ops: list of tuples + Updated forward pass operations + + """ + + # Return immediately if environment is not distributed + if not torch.distributed.is_initialized() or torch.distributed.get_world_size() == 1: + return ops + + # Sliding window in list of ops + window = [] + + def peek_next_op() -> Optional[FusibleOperation]: + """Get next op in list of ops""" + nonlocal ops + if not ops: + return None + return ops[-1][0] + + def pop_next_op() -> FusibleOperation: + """Remove next op from list of ops and add to sliding window""" + nonlocal ops, window + window.insert(0, ops[-1]) + ops = ops[:-1] + return window[0][0] + + # Scan through ops in reverse order, fusing if possible + out_reversed = [] + while ops: + out_reversed.extend(reversed(window)) + window.clear() + + # Check if next op is linear + next_op = pop_next_op() + if not isinstance(next_op, BasicLinear): + continue + linear = next_op + if linear._userbuffers_options is None: + continue + + # Check if next op is bias + bias = None + if linear.tensor_parallel_mode != "row" and isinstance(peek_next_op(), Bias): + bias = pop_next_op() + + # Check if next op is reduce-scatter + reduce_scatter = None + if linear.tensor_parallel_mode is None and isinstance(peek_next_op(), ReduceScatter): + reduce_scatter = pop_next_op() + + # Check for invalid combinations + if reduce_scatter is None: + if linear.tensor_parallel_mode is None: + continue + if linear.tensor_parallel_size == 1: + continue + if linear.tensor_parallel_mode == "row" and bias is not None: + continue + else: + if linear.tensor_parallel_mode is not None: + continue + if reduce_scatter.process_group_size == 1: + continue + + # Replace window with fused op + op = UserbuffersBackwardLinear( + linear=linear, + bias=bias, + reduce_scatter=reduce_scatter, + ) + basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window] + window = [(op, basic_op_idxs)] + + # Return list of ops + out_reversed.extend(reversed(window)) + out = out_reversed + out.reverse() + return out diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py new file mode 100644 index 0000000000..a1b0ca6a9e --- /dev/null +++ b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py @@ -0,0 +1,597 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Linear layer forward with Userbuffers communication.""" + +from __future__ import annotations +from collections.abc import Iterable +from typing import Any, Optional + +import torch + +from transformer_engine_torch import CommOverlapAlgo +from ...cpp_extensions import fp8_gemm, gemm +from ...distributed import get_distributed_world_size +from ...float8_tensor import Float8Tensor +from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype +from ...module.base import get_ub, get_workspace +from ...utils import canonicalize_device, canonicalize_dtype +from ..basic import BasicLinear, Bias, ReduceScatter +from ..op import ( + BasicOperation, + FusedOperation, + FusibleOperation, + OperationContext, +) +from .._common import ( + convert_tensor, + get_fp8_meta_from_fp8_tensor, + is_float8_tensor, + reshape, +) + + +class UserbuffersForwardLinear(FusedOperation): + """Linear forward implementation using Userbuffers + + This operation is equivalent to a linear operation's forward pass, + but it uses Userbuffers to overlap tensor-parallel communication + with compute. 
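From the pattern matching in `fuse_userbuffers_backward_linear` above (and its forward counterpart in this file), two pipeline layouts appear to qualify for replacement. The sketch below lists them; the export path `transformer_engine.pytorch.ops`, the `Bias` and `ReduceScatter` constructor arguments, and the process group are illustrative assumptions, and the Userbuffers communicators for each `comm_name` must already be registered.

    import torch
    import torch.distributed as dist
    from transformer_engine.pytorch.ops import Sequential, BasicLinear, Bias, ReduceScatter

    tp_group = dist.group.WORLD  # placeholder; a tensor-parallel sub-group in practice

    # 1) A linear op that is itself tensor-parallel (column with optional bias,
    #    or row without bias), with sequence parallelism enabled:
    fc1 = Sequential(
        BasicLinear(
            1024,
            4096,
            tensor_parallel_mode="column",
            tensor_parallel_group=tp_group,
            sequence_parallel=True,
            userbuffers_options={"comm_name": "fc1"},
        ),
        Bias(4096),
    )

    # 2) A non-parallel linear followed by an explicit reduce-scatter:
    fc2 = Sequential(
        BasicLinear(4096, 1024, userbuffers_options={"comm_name": "fc2"}),
        ReduceScatter(tp_group),
    )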
+ + """ + + def __init__( + self, + *, + linear: BasicLinear, + bias: Optional[Bias], + reduce_scatter: Optional[ReduceScatter], + ) -> None: + + # Basic operations that comprise this fused operation + op_idxs = {"linear": 0, "bias": None, "reduce_scatter": None} + ops = [linear] + if bias is not None: + op_idxs["bias"] = len(ops) + ops.append(bias) + if reduce_scatter is not None: + op_idxs["reduce_scatter"] = len(ops) + ops.append(reduce_scatter) + + # Initialize base class + super().__init__(ops) + + # Index of each basic operations + self._op_idxs: dict[str, Optional[int]] = op_idxs + + # Tensor parallelism configuration + self.tensor_parallel_mode: Optional[str] + self.tensor_parallel_group: Optional[torch.distributed.ProcessGroup] + self.tensor_parallel_size: int + self.sequence_parallel: bool + if reduce_scatter is None: + self.tensor_parallel_mode = linear.tensor_parallel_mode + self.tensor_parallel_group = linear.tensor_parallel_group + self.tensor_parallel_size = linear.tensor_parallel_size + self.sequence_parallel = linear.sequence_parallel + else: + self.tensor_parallel_mode = "row" + self.tensor_parallel_group = reduce_scatter.process_group + self.tensor_parallel_size = reduce_scatter.process_group_size + self.sequence_parallel = True + + @staticmethod + def _functional_forward( + input: torch.Tensor, # pylint: disable=redefined-builtin + weight: torch.Tensor, + *, + bias: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + tensor_parallel_mode: Optional[str] = None, + tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + tensor_parallel_size: Optional[int] = None, + sequence_parallel: bool = False, + with_fp8_compute: bool = False, + input_fp8_meta: Optional[dict[str, Any]] = None, + weight_fp8_meta: Optional[dict[str, Any]] = None, + output_fp8_meta: Optional[dict[str, Any]] = None, + ub_comm_name: str, + ) -> tuple[torch.Tensor, dict]: + """Functional API for forward pass + + Parameters + ---------- + input: torch.Tensor + Input tensor + weight: torch.Tensor + Weight tensor + bias: torch.Tensor, optional + Bias tensor + device: torch.device, default = default CUDA device + Tensor device + dtype: torch.dtype, default = default dtype + Tensor datatype + tensor_parallel_mode: {`None`, "column", "row"}, default = `None` + Mode for tensor parallelism + tensor_parallel_group: torch.distributed.ProcessGroup, default = world group + Process group for tensor parallelism + sequence_parallel: bool, default = `False` + Whether to apply sequence parallelism together with tensor + parallelism, i.e. distributing input or output tensors + along outer dimension (sequence or batch dim) when not + distributing along inner dimension (embedding dim) + with_fp8_compute: bool, default = `False` + Whether to perform compute in FP8 + input_fp8_meta: dict, optional + FP8 metadata for casting input tensor to FP8. Required for + FP8 compute if input is not already in FP8. + weight_fp8_meta: dict, optional + FP8 metadata for casting weight tensor to FP8. Required for + FP8 compute if weight is not already in FP8. + output_fp8_meta: dict, optional + FP8 metadata for casting output tensor to FP8 + ub_comm_name: str + Layer type (e.g. "qkv", "proj", "fc1", "fc2"). This is + used to access the corresponding Userbuffers communicators + (e.g. "qkv_fprop"). + + Returns + ------- + torch.Tensor + Output tensor + dict + Extra output tensors. "input" is the input tensor, + possibly cast and reshaped from the provided input tensor. 
+ + """ + + # Check device + if device is None: + device = weight.device + device = canonicalize_device(device) + if device.type != "cuda": + raise ValueError(f"Only CUDA devices are supported (got {device})") + + # Check datatype + if dtype is None: + dtype = weight.dtype + dtype = canonicalize_dtype(dtype) + if dtype not in (torch.float32, torch.float16, torch.bfloat16): + raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})") + + # Input tensor dims + input_dims = tuple(input.size()) + weight_dims = tuple(weight.size()) + if len(weight_dims) != 2: + raise ValueError(f"Weight tensor is not 2D (shape={weight_dims})") + if len(input_dims) == 0 or weight_dims[1] != input_dims[-1]: + raise ValueError( + f"Input tensor (shape={input_dims}) " + f"and weight tensor (shape={weight_dims}) " + "are not compatible" + ) + + # Output tensor dims + output_dims = list(input_dims) + output_dims[0] = -1 + output_dims[-1] = weight_dims[0] + + # Check tensor parallel group + if tensor_parallel_size is None: + tensor_parallel_size = get_distributed_world_size(tensor_parallel_group) + if tensor_parallel_size == 1: + tensor_parallel_mode = None + if tensor_parallel_mode not in ("column", "row"): + raise RuntimeError( + "Invalid configuration for Userbuffers " + f"({tensor_parallel_size=}, {tensor_parallel_mode=})" + ) + if not sequence_parallel: + raise RuntimeError(f"Invalid configuration for Userbuffers ({sequence_parallel=})") + + # Check if FP8 is enabled + if with_fp8_compute: + if input_fp8_meta is None and not is_float8_tensor(input): + raise ValueError("No FP8 metadata was provided for casting input to FP8") + if weight_fp8_meta is None and not is_float8_tensor(weight): + raise ValueError("No FP8 metadata was provided for casting weight to FP8") + else: + input_fp8_meta = None + weight_fp8_meta = None + output_fp8_meta = None + with_fp8_output = ( + with_fp8_compute and tensor_parallel_mode != "row" and output_fp8_meta is not None + ) + + # Get Userbuffers communicator + ub_comm = get_ub(ub_comm_name + "_fprop") + ub_local_buffer = ub_comm.get_ubuf_output(0) + ub_global_buffer = ub_comm.get_ubuf_output(1) + with_ub_all_gather = tensor_parallel_mode == "column" + with_ub_reduce_scatter = tensor_parallel_mode == "row" + + # Choose Userbuffers communication algorithm + ub_algo = None + if with_ub_all_gather: + if with_fp8_compute and ub_comm.is_atomic_gemm(): + ub_algo = CommOverlapAlgo.ATOMIC_GEMM_AG_P2P + else: + ub_algo = CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P + elif with_ub_reduce_scatter: + is_atomic_gemm = with_fp8_compute and ub_comm.is_atomic_gemm() + ub_algo = { + (True, True): CommOverlapAlgo.ATOMIC_GEMM_RS_P2P, + (True, False): CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P, + (False, True): CommOverlapAlgo.ATOMIC_GEMM_RS, + (False, False): CommOverlapAlgo.SPLIT_PIPELINED_RS, + }[(ub_comm.is_p2p_overlap(), is_atomic_gemm)] + else: + raise RuntimeError("Could not choose Userbuffers communication algorithm") + + # Cast input tensor to correct dtype + x_local = reshape( + input, + (-1, input_dims[-1]), + device=device, + dtype=dtype, + ) + if with_fp8_compute and not is_float8_tensor(x_local): + fp8_dtype = get_fp8_te_dtype( + input_fp8_meta["recipe"], + fprop_tensor=True, + ) + with_transpose_cache = weight.requires_grad + if tensor_parallel_mode == "column" and sequence_parallel: + with_transpose_cache = False + x_local = Float8Tensor.to_float8( + x_local, + fp8_meta=input_fp8_meta, + fp8_meta_forward=True, + fp8_meta_index=0, + fp8_dtype=fp8_dtype, + data=(ub_local_buffer 
if with_ub_all_gather else None), + with_transpose_cache=with_transpose_cache, + ) + elif not with_fp8_compute and is_float8_tensor(x_local): + if with_ub_all_gather: + x_local = ub_local_buffer.copy_(x_local) + else: + x_local = x_local.dequantize() + + # Initialize buffers for UB all-gather if needed + x = x_local + if with_ub_all_gather: + if with_fp8_compute: + x = Float8Tensor.make_like(x_local, data=ub_global_buffer) + if x_local._data.data_ptr() != ub_local_buffer.data_ptr(): + ub_local_buffer.copy_(x_local._data) + else: + x_local._data = torch.empty_like(x_local._data) + else: + x = ub_global_buffer + if x_local.data_ptr() != ub_local_buffer.data_ptr(): + ub_local_buffer.copy_(x_local) + else: + x_local = torch.empty_like(x_local) + + # Check weight tensor + w = convert_tensor( + weight, + device=device, + dtype=dtype, + memory_format=torch.contiguous_format, + ) + if with_fp8_compute and not is_float8_tensor(w): + fp8_dtype = get_fp8_te_dtype( + weight_fp8_meta["recipe"], + fprop_tensor=True, + ) + w = Float8Tensor.to_float8( + w, + fp8_meta=weight_fp8_meta, + fp8_meta_index=0, + fp8_dtype=fp8_dtype, + ) + elif not with_fp8_compute and is_float8_tensor(w): + w = w.dequantize() + + # Check bias tensor + b = None + if bias is not None: + b = convert_tensor( + bias, + device=device, + dtype=dtype, + memory_format=torch.contiguous_format, + ) + + # Construct output tensor + y = None + y_local = None + if with_ub_reduce_scatter: + # Initialize buffers for UB reduce-scatter + if with_fp8_output: + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True) + fp8_dtype = get_fp8_te_dtype( + output_fp8_meta["recipe"], + fprop_tensor=True, + ) + y = Float8Tensor( + data=ub_global_buffer, + fp8_meta=output_fp8_meta, + fp8_meta_forward=True, + fp8_meta_index=0, + fp8_dtype=fp8_dtype, + fp8_scale_inv=output_fp8_meta[fp8_meta_key].scale_inv[0], + dtype=dtype, + ) + ub_comm.set_ubuf_scale_inv(y._scale_inv) + else: + y = ub_global_buffer + y_local = torch.empty( + (x.size(0) // tensor_parallel_size, weight_dims[0]), + dtype=dtype, + device=device, + ) + else: + # Allocate output tensor + if with_fp8_output: + fp8_dtype = get_fp8_te_dtype( + output_fp8_meta["recipe"], + fprop_tensor=True, + ) + data = torch.empty( + (x.size(0), weight_dims[0]), + dtype=torch.uint8, + device=device, + ) + y = Float8Tensor( + data=data, + fp8_meta=output_fp8_meta, + fp8_meta_forward=True, + fp8_meta_index=0, + fp8_dtype=fp8_dtype, + dtype=dtype, + ) + else: + y = torch.empty( + (x.size(0), weight_dims[0]), + dtype=dtype, + device=device, + ) + y_local = y + + # Perform GEMM + if with_fp8_compute: + kwargs = { + "out": y, + "bias": b, + "use_bias": (b is not None), + "use_split_accumulator": False, + "ub_algo": ub_algo, + "ub": ub_comm, + } + if with_ub_all_gather: + kwargs["extra_output_tensor"] = x_local._data + if with_ub_reduce_scatter: + kwargs["extra_output_tensor"] = y_local + if with_fp8_output: + fp8_meta, fp8_meta_index = get_fp8_meta_from_fp8_tensor(y) + kwargs.update( + { + "out": y._data, + "out_index": fp8_meta_index, + "fp8_meta_tensor": fp8_meta, + "D_dtype": y._fp8_dtype, + } + ) + fp8_gemm( + w._data, + w._scale_inv, + 0, + w._fp8_dtype, + x._data, + x._scale_inv, + 0, + x._fp8_dtype, + y.dtype, + get_workspace(), + **kwargs, + ) + else: + kwargs = { + "out": y, + "bias": b, + "use_bias": (b is not None), + "ub_algo": ub_algo, + "ub": ub_comm, + } + if with_ub_all_gather: + kwargs["extra_output_tensor"] = x_local + if with_ub_reduce_scatter: + kwargs["extra_output_tensor"] = y_local + 
gemm(w, x, y.dtype, get_workspace(), **kwargs) + + # Reshape output tensor + out = reshape(y_local, output_dims) + + # Return cast tensors + extra_outputs = {"input": x_local, "weight": w} + return out, extra_outputs + + def fuser_forward( + self, + basic_op_ctxs: list[OperationContext], + input_: torch.Tensor, + *, + basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], + basic_op_prev_ops: list[Optional[BasicOperation]], + basic_op_next_ops: list[Optional[BasicOperation]], + basic_op_kwargs: list[dict[str, Any]], + ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: + + # Get basic operations + idx = self._op_idxs["linear"] + linear_op = self.basic_ops[idx] + linear_op_ctx = basic_op_ctxs[idx] + bias_op = None + bias = None + if self._op_idxs["bias"] is not None: + idx = self._op_idxs["bias"] + bias_op = self.basic_ops[idx] + bias = bias_op.bias + if basic_op_kwargs[idx]: + raise ValueError("Bias operation forward does not expect keyword arguments") + + # FP8 metadata + with_fp8_compute = FP8GlobalStateManager.is_fp8_enabled() + input_fp8_meta = None + weight_fp8_meta = None + output_fp8_meta = None + grad_output_fp8_meta = None + grad_input_fp8_meta = None + if with_fp8_compute: + input_fp8_meta = linear_op.get_fp8_meta("input") + weight_fp8_meta = linear_op.get_fp8_meta("param") + next_op = basic_op_next_ops[-1] + if next_op is not None and next_op.num_fp8_scales("input") > 0: + output_fp8_meta = next_op.get_fp8_meta("input") + grad_output_fp8_meta = linear_op.get_fp8_meta("grad_output") + prev_op = basic_op_prev_ops[0] + if prev_op is not None and prev_op.num_fp8_scales("grad_output") > 0: + grad_input_fp8_meta = prev_op.get_fp8_meta("grad_output") + + # Get autocast dtype if needed + dtype = None + if torch.is_autocast_enabled(): + dtype = torch.get_autocast_dtype("cuda") + + # Userbuffers options + if linear_op._userbuffers_options is None: + raise RuntimeError("Linear op is missing dict for Userbuffers options") + + # Linear forward + output, extra_outputs = UserbuffersForwardLinear._functional_forward( + input=input_, + weight=linear_op.weight, + bias=bias, + device=linear_op.device, + dtype=dtype, + tensor_parallel_mode=self.tensor_parallel_mode, + tensor_parallel_group=self.tensor_parallel_group, + tensor_parallel_size=self.tensor_parallel_size, + sequence_parallel=self.sequence_parallel, + with_fp8_compute=with_fp8_compute, + input_fp8_meta=input_fp8_meta, + weight_fp8_meta=weight_fp8_meta, + output_fp8_meta=output_fp8_meta, + ub_comm_name=linear_op._userbuffers_options["comm_name"], + ) + x_local = extra_outputs["input"] + + # Save state for backward pass + linear_op_ctx.save_for_backward(x_local) + linear_op_ctx.with_fp8_compute = with_fp8_compute + linear_op_ctx.weight_fp8_meta = weight_fp8_meta + linear_op_ctx.grad_output_fp8_meta = grad_output_fp8_meta + linear_op_ctx.grad_input_fp8_meta = grad_input_fp8_meta + linear_op_ctx.dtype = dtype + linear_op_ctx.input_dims = input_.size() + linear_op_ctx.input_requires_grad = input_.requires_grad + linear_op_ctx.weight_requires_grad = linear_op.weight.requires_grad + linear_op_ctx.has_prev_op = basic_op_prev_ops[0] is not None + + return output, [() for _ in range(len(self.basic_ops))] + + +def fuse_userbuffers_forward_linear( + ops: list[tuple[FusibleOperation, list[int]]], +) -> list[tuple[FusibleOperation, list[int]]]: + """Substitute linear operations with Userbuffers implementation + + Parameters + ---------- + ops: list of tuples + Forward pass operations and the indices of the corresponding + basic operations. 
+ + Returns + ------- + ops: list of tuples + Updated forward pass operations + + """ + + # Return immediately if environment is not distributed + if not torch.distributed.is_initialized() or torch.distributed.get_world_size() == 1: + return ops + + # Sliding window in list of ops + window = [] + + def peek_next_op() -> Optional[FusibleOperation]: + """Get next op in list of ops""" + nonlocal ops + if not ops: + return None + return ops[0][0] + + def pop_next_op() -> FusibleOperation: + """Remove next op from list of ops and add to sliding window""" + nonlocal ops, window + window.append(ops[0]) + ops = ops[1:] + return window[-1][0] + + # Scan through ops, fusing if possible + out = [] + while ops: + out.extend(window) + window.clear() + + # Check if next op is linear + next_op = pop_next_op() + if not isinstance(next_op, BasicLinear): + continue + linear = next_op + if linear._userbuffers_options is None: + continue + + # Check if next op is bias + bias = None + if linear.tensor_parallel_mode != "row" and isinstance(peek_next_op(), Bias): + bias = pop_next_op() + + # Check if next op is reduce-scatter + reduce_scatter = None + if linear.tensor_parallel_mode is None and isinstance(peek_next_op(), ReduceScatter): + reduce_scatter = pop_next_op() + + # Check for invalid combinations + if reduce_scatter is None: + if linear.tensor_parallel_mode is None: + continue + if linear.tensor_parallel_size == 1: + continue + if linear.tensor_parallel_mode == "row" and bias is not None: + continue + else: + if linear.tensor_parallel_mode is not None: + continue + if reduce_scatter.process_group_size == 1: + continue + + # Replace window with fused op + op = UserbuffersForwardLinear( + linear=linear, + bias=bias, + reduce_scatter=reduce_scatter, + ) + basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window] + window = [(op, basic_op_idxs)] + + # Return list of ops + out.extend(window) + return out diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index a7c99c592d..8b2a04cff8 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ b/transformer_engine/pytorch/ops/fuser.py @@ -5,12 +5,12 @@ """Manager class for a pipeline of fusible operations.""" from __future__ import annotations +from collections.abc import Callable from typing import Any, Optional import torch from transformer_engine.pytorch.fp8 import FP8GlobalStateManager -from transformer_engine.pytorch.graph import is_graph_capturing from transformer_engine.pytorch.ops.op import ( BasicOperation, FusibleOperation, @@ -20,6 +20,8 @@ fuse_backward_linear_add, fuse_forward_linear_bias_activation, fuse_forward_linear_bias_add, + fuse_userbuffers_backward_linear, + fuse_userbuffers_forward_linear, ) @@ -28,6 +30,24 @@ def _split_tuple(t: tuple, idx: int) -> tuple[tuple, tuple]: return t[:idx], t[idx:] +# Lazily imported function used in _is_graph_capturing +_is_graph_capturing_function: Optional[Callable[[], bool]] = None + + +def _is_graph_capturing() -> bool: + """Whether function is called within `make_graphed_callables` + + Avoid circular import with lazy import. 
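As an editorial aside, the lazy-import pattern used by `_is_graph_capturing` can be illustrated in isolation. A minimal, self-contained sketch (the module and function names below are placeholders, not the Transformer Engine layout):

    from collections.abc import Callable
    from typing import Optional

    # Cache for the lazily resolved function; stays None until the first call.
    _cached_fn: Optional[Callable[[float], bool]] = None

    def lazy_isfinite(x: float) -> bool:
        """Stand-in for a helper whose import would otherwise create a cycle."""
        global _cached_fn
        if _cached_fn is None:
            # Deferring the import to call time breaks the circular dependency.
            from math import isfinite
            _cached_fn = isfinite
        return _cached_fn(x)

    print(lazy_isfinite(1.0))  # True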
+ + """ + global _is_graph_capturing_function + if _is_graph_capturing_function is None: + from ..graph import is_graph_capturing + + _is_graph_capturing_function = is_graph_capturing + return _is_graph_capturing_function() + + class _OperationFuserAutogradFunction(torch.autograd.Function): """Autograd function for a pipeline of operations @@ -39,12 +59,13 @@ class _OperationFuserAutogradFunction(torch.autograd.Function): # pylint: disable=unused-argument @staticmethod def forward( - func_ctx: torch.autograd.function.FunctionCtx, + func_ctx: Optional[torch.autograd.function.FunctionCtx], input_: torch.Tensor, forward_ops: list[tuple[FusibleOperation, list[int]]], backward_ops: list[tuple[FusibleOperation, list[int]]], basic_ops: list[BasicOperation], basic_op_kwargs: list[dict[str, Any]], + is_grad_enabled: bool, num_params: int, num_extra_inputs: int, *params_and_extra_inputs: torch.nn.Parameter, @@ -102,10 +123,24 @@ def forward( # Apply forward ops x = input_ - requires_grad = x.requires_grad + requires_grad = is_grad_enabled and x.requires_grad extra_outputs = [None for _ in range(len(basic_ops))] for op, basic_op_idxs in forward_ops: + # Check if backward op is required + if is_grad_enabled: + if not requires_grad: + requires_grad = any(param.requires_grad for param in op.parameters()) + if not requires_grad: + requires_grad = any(any(x.requires_grad for x in xs) for xs in extra_inputs) + for idx in basic_op_idxs: + basic_op_ctxs[idx].requires_grad = requires_grad + if requires_grad != x.requires_grad: + if requires_grad: + x.requires_grad_() + else: + x = x.detach() + # Forward op extra_inputs = [basic_op_extra_inputs[idx] for idx in basic_op_idxs] prev_ops = [basic_ops[idx - 1] if idx > 0 else None for idx in basic_op_idxs] @@ -120,18 +155,12 @@ def forward( basic_op_next_ops=next_ops, basic_op_kwargs=[basic_op_kwargs[idx] for idx in basic_op_idxs], ) + x.requires_grad_(requires_grad=requires_grad) for idx, ys in zip(basic_op_idxs, fused_op_extra_outputs): + for y in ys: + y.requires_grad_(requires_grad=requires_grad) extra_outputs[idx] = ys - # Check if backward op is required - if not requires_grad: - requires_grad = any(param.requires_grad for param in op.parameters()) - if not requires_grad: - requires_grad = any(any(x.requires_grad for x in xs) for xs in extra_inputs) - for idx in basic_op_idxs: - basic_op_ctxs[idx]._requires_grad = requires_grad - x.requires_grad_(requires_grad=requires_grad) - # Flatten list of extra outputs extra_outputs_flat = [] for idx, ys in enumerate(extra_outputs): @@ -145,25 +174,28 @@ def forward( ) extra_outputs_flat.extend(ys) - # Flatten list of saved tensors - to_save = [] - for ctx in basic_op_ctxs: - range_start = len(to_save) - if ctx.to_save is not None: - to_save.extend(ctx.to_save) - range_end = len(to_save) - ctx.to_save = None - ctx._saved_tensors_range = (range_start, range_end) - func_ctx.save_for_backward(*to_save) - - # Other context for backward pass - func_ctx.backward_ops = backward_ops - func_ctx.basic_ops = basic_ops - func_ctx.basic_op_ctxs = basic_op_ctxs - func_ctx.num_params = num_params - func_ctx.num_extra_inputs = num_extra_inputs - func_ctx.num_extra_outputs = len(extra_outputs_flat) - func_ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module() + # Save context for backward pass + if is_grad_enabled: + + # Flatten list of saved tensors + to_save = [] + for ctx in basic_op_ctxs: + range_start = len(to_save) + if ctx.to_save is not None: + to_save.extend(ctx.to_save) + range_end = len(to_save) + ctx.to_save = 
None + ctx._saved_tensors_range = (range_start, range_end) + func_ctx.save_for_backward(*to_save) + + # Other context + func_ctx.backward_ops = backward_ops + func_ctx.basic_ops = basic_ops + func_ctx.basic_op_ctxs = basic_op_ctxs + func_ctx.num_params = num_params + func_ctx.num_extra_inputs = num_extra_inputs + func_ctx.num_extra_outputs = len(extra_outputs_flat) + func_ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module() if extra_outputs_flat: return x, *extra_outputs_flat @@ -206,7 +238,7 @@ def backward( for op, basic_op_idxs in backward_ops: # Stop if no more gradients are required - if all(not basic_op_ctxs[idx]._requires_grad for idx in basic_op_idxs): + if all(not basic_op_ctxs[idx].requires_grad for idx in basic_op_idxs): dx = None break @@ -255,7 +287,7 @@ def backward( grad_extra_inputs_flat.extend(dxs) # Update FP8 scaling factors - if func_ctx.is_first_module and not is_graph_capturing(): + if func_ctx.is_first_module and not _is_graph_capturing(): FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) return ( @@ -264,6 +296,7 @@ def backward( None, # backward_ops None, # basic_ops None, # basic_op_kwargs + None, # is_grad_enabled None, # num_params None, # num_extra_inputs *grad_params_flat, @@ -318,6 +351,7 @@ def _fuse_forward_ops( ops: list[tuple[FusibleOperation, list[int]]], ) -> list[tuple[FusibleOperation, list[int]]]: """Attempt to fuse operations in forward pass""" + ops = fuse_userbuffers_forward_linear(ops) ops = fuse_forward_linear_bias_add(ops) ops = fuse_forward_linear_bias_activation(ops) return ops @@ -328,6 +362,7 @@ def _fuse_backward_ops( ops: list[tuple[FusibleOperation, list[int]]], ) -> list[tuple[FusibleOperation, list[int]]]: """Attempt to fuse operations in backward pass""" + ops = fuse_userbuffers_backward_linear(ops) ops = fuse_backward_linear_add(ops) return ops @@ -355,14 +390,23 @@ def __call__( params = [param for op in self._basic_ops for param in op.parameters()] # Fuser forward pass - return _OperationFuserAutogradFunction.apply( + is_grad_enabled = torch.is_grad_enabled() + if is_grad_enabled: + forward_func = _OperationFuserAutogradFunction.apply + args = [] + else: + forward_func = _OperationFuserAutogradFunction.forward + args = [None] + args += ( input, self._forward_ops, self._backward_ops, self._basic_ops, basic_op_kwargs, + is_grad_enabled, len(params), self._num_extra_inputs, *params, *extra_inputs, ) + return forward_func(*args) diff --git a/transformer_engine/pytorch/ops/linear.py b/transformer_engine/pytorch/ops/linear.py index daa5a6952e..68472f171a 100644 --- a/transformer_engine/pytorch/ops/linear.py +++ b/transformer_engine/pytorch/ops/linear.py @@ -133,6 +133,38 @@ def __init__( # Initialize base class super().__init__(ops) - # Register parameters - self.register_parameter("weight", self.basic_ops[0].weight) - self.register_parameter("bias", self.basic_ops[1].bias if bias else None) + self._has_bias: bool = bias + + @property + def weight(self) -> torch.nn.Parameter: + """Weight tensor + + Parameter is owned by `BasicLinear` operation. + + """ + return self.basic_ops[0].weight + + @weight.setter + def weight(self, value: Optional[torch.nn.Parameter]) -> None: + self.basic_ops[0].weight = value + + @property + def bias(self) -> Optional[torch.nn.Parameter]: + """Bias tensor + + Parameter is owned by `Bias` operation. 
+ + """ + if self._has_bias: + return self.basic_ops[1].bias + return None + + @bias.setter + def bias(self, value: Optional[torch.nn.Parameter]) -> None: + if self._has_bias: + self.basic_ops[1].bias = value + elif value is not None: + raise ValueError( + "Attempted to set bias parameter in Linear operation " + "that does not have bias enabled" + ) diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py index 75905ad854..04a66b7942 100644 --- a/transformer_engine/pytorch/ops/op.py +++ b/transformer_engine/pytorch/ops/op.py @@ -8,16 +8,18 @@ import abc from collections.abc import Iterable import dataclasses +import pickle from typing import Any, Optional import torch import transformer_engine_torch as tex from transformer_engine.pytorch.fp8 import ( + DelayedScaling, FP8GlobalStateManager, get_default_fp8_recipe, ) -from ._common import canonicalize_device, is_float8_tensor +from ._common import canonicalize_device @dataclasses.dataclass @@ -41,7 +43,7 @@ class OperationContext: _saved_tensors_range: Optional[tuple[int, int]] = None # Whether backward pass is required - _requires_grad: bool = False + requires_grad: bool = True def save_for_backward(self, *tensors: Optional[torch.Tensor]) -> None: """Register tensors to be saved for the backward function @@ -231,25 +233,37 @@ def _make_meta( } @classmethod - def _maybe_update_fp8_meta(cls, fp8_meta: Optional[dict[str, Any]]) -> None: + def _maybe_update_fp8_meta( + cls, + fp8_meta: Optional[dict[str, Any]], + *, + fp8_recipe: Optional[DelayedScaling] = None, + ) -> None: if fp8_meta is None: return - # Update FP8 recipe and communication group - recipe = FP8GlobalStateManager.get_fp8_recipe() - fp8_meta["recipe"] = recipe + # Update FP8 recipe + if fp8_recipe is None: + fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() + fp8_meta["recipe"] = fp8_recipe + + # Update FP8 communication group fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group() # Adjust amax history length if needed - amax_history_len = recipe.amax_history_len + amax_history_len = fp8_recipe.amax_history_len for is_forward in (True, False): - key = FP8GlobalStateManager.get_meta_tensor_key(forward=is_forward) - if key not in fp8_meta: + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=is_forward) + if fp8_meta_key not in fp8_meta: continue - meta = fp8_meta[key] + meta = fp8_meta[fp8_meta_key] curr_len = meta.amax_history.size(0) + + # Nothing to be done if amax history is already correct if curr_len == amax_history_len: continue + + # Reallocate amax history with torch.no_grad(): if curr_len > amax_history_len: meta.amax_history = meta.amax_history[:amax_history_len].clone() @@ -259,6 +273,21 @@ def _maybe_update_fp8_meta(cls, fp8_meta: Optional[dict[str, Any]]) -> None: pad=(0, 0, 0, amax_history_len - curr_len), ) + # Update global buffers for amax reductions + buffer_info_key = FP8GlobalStateManager.get_buffer_info() + if buffer_info_key in fp8_meta: + fwd_pos, fwd_key, bwd_pos, bwd_key = fp8_meta[buffer_info_key] + for pos, buffer_key in zip((fwd_pos, bwd_pos), (fwd_key, bwd_key)): + assert ( + buffer_key in FP8GlobalStateManager.global_amax_history_buffer + ), "TE internal error during amax history change." 
+ FP8GlobalStateManager.global_amax_buffer[buffer_key][pos] = fp8_meta[ + fp8_meta_key + ].amax_history[0] + FP8GlobalStateManager.global_amax_history_buffer[buffer_key][pos] = fp8_meta[ + fp8_meta_key + ].amax_history + def get_fp8_meta(self, mode: str) -> Optional[dict[str, Any]]: """FP8 metadata @@ -272,11 +301,67 @@ def get_fp8_meta(self, mode: str) -> Optional[dict[str, Any]]: self._fp8_metas = self._make_fp8_metas() return self._fp8_metas[mode] - def pre_forward(self) -> None: + @torch.no_grad() + def _save_fp8_metas(self) -> Optional[dict[str, Any]]: + """Create copies of tensors in FP8 metadata + + Tensor copies can be loaded with _load_fp8_metas. + + """ + if self._fp8_metas is None: + return None + out = {} + for mode, fp8_meta in self._fp8_metas.items(): + if fp8_meta is None: + continue + out[mode] = {} + for is_forward in (True, False): + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=is_forward) + if fp8_meta_key not in fp8_meta: + continue + out[mode][fp8_meta_key] = ( + fp8_meta[fp8_meta_key].scale.clone(), + fp8_meta[fp8_meta_key].scale_inv.clone(), + fp8_meta[fp8_meta_key].amax_history.clone(), + ) + return out + + @torch.no_grad() + def _load_fp8_metas(self, fp8_metas: Optional[dict[str, Any]]) -> None: + """Update FP8 metadata with saved tensor copies + + Tensor copies should be generated with _save_fp8_metas. + + """ + assert (self._fp8_metas is None) == ( + fp8_metas is None + ), "Saved FP8 metadata does not match operation's FP8 metadata" + if fp8_metas is None: + return + for mode, fp8_meta in fp8_metas.items(): + assert ( + mode in self._fp8_metas + ), f"Found an unexpected key ({mode=}) in saved FP8 metadata" + for fp8_meta_key, tensors in fp8_meta.items(): + assert ( + fp8_meta_key in self._fp8_metas[mode] + ), f"Found an unexpected key ({mode=}, {fp8_meta_key=}) in saved FP8 metadata" + scale, scale_inv, amax_history = tensors + self._fp8_metas[mode][fp8_meta_key].scale.copy_(scale) + self._fp8_metas[mode][fp8_meta_key].scale_inv.copy_(scale_inv) + self._fp8_metas[mode][fp8_meta_key].amax_history.copy_(amax_history) + + def pre_forward( + self, + *, + fp8_enabled: Optional[bool] = None, + fp8_recipe: Optional[DelayedScaling] = None, + ) -> None: """Preprocessing before forward pass""" # Initialize FP8 metadata if needed - fp8_enabled = FP8GlobalStateManager.is_fp8_enabled() + if fp8_enabled is None: + fp8_enabled = FP8GlobalStateManager.is_fp8_enabled() if fp8_enabled: # Construct FP8 metadata if needed @@ -285,7 +370,7 @@ def pre_forward(self) -> None: # Make sure FP8 metadata matches FP8 autocast context for fp8_meta in self._fp8_metas.values(): - self._maybe_update_fp8_meta(fp8_meta) + self._maybe_update_fp8_meta(fp8_meta, fp8_recipe=fp8_recipe) # Register FP8 metadata for amax and scale update if not FP8GlobalStateManager.fp8_graph_capturing(): @@ -294,10 +379,8 @@ def pre_forward(self) -> None: self.get_fp8_meta("input"), ) if self.num_fp8_scales("param"): - fp8_params = list(filter(is_float8_tensor, self.parameters())) FP8GlobalStateManager.add_fp8_tensors_to_global_buffer( self.get_fp8_meta("param"), - fp8_weights=(fp8_params if fp8_params else None), ) if self.num_fp8_scales("grad_output"): FP8GlobalStateManager.add_fp8_tensors_to_global_buffer( @@ -420,6 +503,161 @@ def forward( basic_op_kwargs=[kwargs], ) + def get_extra_state(self) -> torch.Tensor: + """Serialize extra state + + Contains metadata for FP8 casting. 
+ + """ + + # This implementation is working around a few issues: + # + # (1) PyTorch's "extra state" infrastructure might be able to + # support any picklable type, but they make no guarantees. + # It seems that ONNX export experiences issues with + # non-tensor extra state. + # (2) PyTorch's checkpointing infrastructure does not remap + # devices for "extra state" like it does for "state dict". + # Thus, we want to avoid putting extra state on the GPU + # since it may be loaded on the wrong device. + # (3) The extra state consists of many small tensors. If we + # want to copy them all to CPU, then we need to avoid the + # overhead of many GPU-CPU memory transfers. + # + # See: https://github.com/NVIDIA/TransformerEngine/pull/351 + # See: https://github.com/NVIDIA/TransformerEngine/pull/363 + + # Return immediately if op has no FP8 state + has_fp8_state = any( + self.num_fp8_scales(mode) > 0 for mode in ("input", "param", "grad_output") + ) + if not has_fp8_state: + return torch.Tensor() + + def to_cpu(src: torch.Tensor) -> torch.Tensor: + """Helper function to make CPU copy of tensor + + Memory transfer is asynchronous w.r.t. host, so GPU should + be synchronized before using result. + + """ + dst = torch.empty_like(src, device="cpu") + dst.copy_(src, non_blocking=True) + return dst + + # Store FP8 state + state = {} + for mode in ("input", "param", "grad_output"): + + # Get state for a given FP8 tensor + if self.num_fp8_scales(mode) == 0: + state[mode] = None + continue + fp8_meta = self.get_fp8_meta(mode) + if fp8_meta is None: + continue + state[mode] = {} + + # Store tensors + if "scaling_fwd" in fp8_meta: + state[mode]["scale_fwd"] = to_cpu(fp8_meta["scaling_fwd"].scale) + state[mode]["scale_inv_fwd"] = to_cpu(fp8_meta["scaling_fwd"].scale_inv) + state[mode]["amax_history_fwd"] = to_cpu(fp8_meta["scaling_fwd"].amax_history) + if "scaling_bwd" in fp8_meta: + state[mode]["scale_bwd"] = to_cpu(fp8_meta["scaling_bwd"].scale) + state[mode]["scale_inv_bwd"] = to_cpu(fp8_meta["scaling_bwd"].scale_inv) + state[mode]["amax_history_bwd"] = to_cpu(fp8_meta["scaling_bwd"].amax_history) + + # Store other picklable items + extra = {} + for key, val in fp8_meta.items(): + if key == "buffer_index_and_autocast_key": + continue + if not isinstance(val, (bool, int, float, str, tuple, list)): + continue + extra[key] = val + state[mode]["extra_fp8_variables"] = extra + + # Serialize state into byte tensor + torch.cuda.synchronize() + state_serialized = bytearray(pickle.dumps(state)) + state_serialized = torch.frombuffer(state_serialized, dtype=torch.uint8) + return state_serialized + + def set_extra_state(self, state: Optional[torch.Tensor]) -> None: + """Load extra state""" + if state is None or state.numel() == 0: + return + + # Deserialize state from byte tensor + state = pickle.loads(state.detach().numpy(force=True).tobytes()) + if state is None: + return + + def copy_tensor(src: torch.Tensor, dst: torch.Tensor) -> None: + """Helper function to copy tensor from CPU + + Memory transfer is asynchronous w.r.t. host, so GPU should + be synchronized before using result. 
+ + """ + if src.size() != dst.size(): + dst.data = torch.empty(src.size(), dtype=dst.dtype, device=dst.device) + dst.copy_(src, non_blocking=True) + + # Load FP8 state + for mode in ("input", "param", "grad_output"): + + # Get state for a given FP8 tensor + if mode not in state: + continue + if self.num_fp8_scales(mode) == 0: + continue + fp8_meta = self.get_fp8_meta(mode) + if fp8_meta is None: + continue + + # Load extra state + fp8_meta.update(state[mode]["extra_fp8_variables"]) + if "amax_history_fwd" in state[mode]: + fp8_meta["recipe"].amax_history_len = state[mode]["amax_history_fwd"].size(0) + elif "amax_history_bwd" in state[mode]: + fp8_meta["recipe"].amax_history_len = state[mode]["amax_history_bwd"].size(0) + if "global_fp8_buffer_pos_fwd_recompute" in fp8_meta: + del fp8_meta["global_fp8_buffer_pos_fwd_recompute"] + + # Load tensors + fp8_meta = self.get_fp8_meta(mode) + if "scaling_fwd" in fp8_meta: + fp8_meta_fwd = fp8_meta["scaling_fwd"] + copy_tensor(state[mode]["scale_fwd"], fp8_meta_fwd.scale) + copy_tensor(state[mode]["scale_inv_fwd"], fp8_meta_fwd.scale_inv) + copy_tensor(state[mode]["amax_history_fwd"], fp8_meta_fwd.amax_history) + if "scaling_bwd" in fp8_meta: + fp8_meta_bwd = fp8_meta["scaling_bwd"] + copy_tensor(state[mode]["scale_bwd"], fp8_meta_bwd.scale) + copy_tensor(state[mode]["scale_inv_bwd"], fp8_meta_bwd.scale_inv) + copy_tensor(state[mode]["amax_history_bwd"], fp8_meta_bwd.amax_history) + + # Finish CPU-GPU memory transfers + torch.cuda.synchronize() + + def _load_from_state_dict(self, *args, **kwargs) -> None: + """Load state""" + + # In the base PyTorch module class, the extra state is loaded + # _after_ the parameters. However, copying values into FP8 + # parameters requires an FP8 cast, which uses a scaling factor + # from the operation's FP8 metadata. The FP8 metadata is + # included in the operation's extra state, so we need to + # manually load the extra state before loading parameters. + + state_dict, prefix = args[0], args[1] + extra_state_key = prefix + torch.nn.modules.module._EXTRA_STATE_KEY_SUFFIX + if extra_state_key in state_dict: + self.set_extra_state(state_dict[extra_state_key]) + super()._load_from_state_dict(*args, **kwargs) + class FusedOperation(FusibleOperation): """Compound tensor operation supported by the operation fuser diff --git a/transformer_engine/pytorch/optimizers/fused_adam.py b/transformer_engine/pytorch/optimizers/fused_adam.py index 191c98745d..93f6191dfe 100644 --- a/transformer_engine/pytorch/optimizers/fused_adam.py +++ b/transformer_engine/pytorch/optimizers/fused_adam.py @@ -3,11 +3,15 @@ # See LICENSE for license information. """Fused Adam optimizer.""" +from copy import deepcopy +from itertools import chain + import torch import transformer_engine_torch as tex from transformer_engine.pytorch.float8_tensor import Float8Tensor from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from .multi_tensor_apply import multi_tensor_applier +from ..float8_tensor import Float8Tensor def get_fp8_meta(fp8_tensor): @@ -68,11 +72,28 @@ class FusedAdam(torch.optim.Optimizer): method is called. (default: True) capturable (bool, optional): whether to use the version of the optimizer that can be used with CUDA Graphs. (default: False) - master_weights (list of torch.Tensor, optional): master weights to use - for mixed precision training. If provided, the optimizer will update - the master weights and then cast the master weights to the model weights. 
-            If not provided, the optimizer will update the model weights directly.
-            (default: None)
+        master_weights (bool, optional): whether to maintain FP32 master weights
+            in the optimizer with FP16/BF16 mixed precision training.
+            (default: False)
+        master_weight_dtype (torch.dtype, optional): The dtype of master weights.
+            If master_weights is False, this will be ignored. It can be one of
+            [torch.float32, torch.float16]. If it's not torch.float32, the optimizer
+            will create an FP32 scalar scaling factor to ensure precision.
+            (default: torch.float32)
+        exp_avg_dtype (torch.dtype, optional): The dtype of exp_avg. It can be
+            one of [torch.float32, torch.float16, torch.uint8], where torch.uint8
+            represents FP8. If it's not torch.float32, the optimizer will create
+            an FP32 scalar scaling factor to ensure precision.
+            (default: torch.float32)
+        exp_avg_sq_dtype (torch.dtype, optional): The dtype of exp_avg_sq. It
+            can be one of [torch.float32, torch.float16, torch.uint8], where
+            torch.uint8 represents FP8. If it's not torch.float32, the optimizer
+            will create an FP32 scalar scaling factor to ensure precision.
+            (default: torch.float32)
+        use_decoupled_grad (bool, optional): Whether to use ".decoupled_grad"
+            instead of ".grad" for reading gradients. It's useful when the dtypes
+            of grad and param are different.
+            (default: False)

    .. _Adam - A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
@@ -92,12 +113,36 @@ def __init__(
        amsgrad=False,
        set_grad_none=True,
        capturable=False,
-        master_weights=None,
+        master_weights=False,
+        master_weight_dtype=torch.float32,
+        exp_avg_dtype=torch.float32,
+        exp_avg_sq_dtype=torch.float32,
+        use_decoupled_grad=False,
    ):
        if amsgrad:
            raise RuntimeError("FusedAdam does not support the AMSGrad variant.")

+        # Add constraints to dtypes of states.
+        if master_weights and master_weight_dtype not in [torch.float32, torch.float16]:
+            raise RuntimeError("FusedAdam only supports fp32/fp16 master weights.")
+        if exp_avg_dtype not in [torch.float32, torch.float16, torch.uint8]:
+            raise RuntimeError("FusedAdam only supports fp32/fp16/fp8 exp_avg.")
+        if exp_avg_sq_dtype not in [torch.float32, torch.float16, torch.uint8]:
+            raise RuntimeError("FusedAdam only supports fp32/fp16/fp8 exp_avg_sq.")
+
+        # Currently, capturable mode only supports fp32 master weights and optimizer states.
+        # The reason is that if the master weights or optimizer states are not in fp32 dtype,
+        # they will be copied to temporary fp32 buffers first. These fp32 buffers are then
+        # used as inputs for the kernel. Consequently, the pointer for each `.step()` differs,
+        # making CUDA Graph inapplicable in this scenario.
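For reference, a hedged usage sketch of the new constructor options, assuming Transformer Engine is installed, a CUDA device is available, and `FusedAdam` is exposed via `transformer_engine.pytorch.optimizers`:

    import torch
    from transformer_engine.pytorch.optimizers import FusedAdam

    # BF16 model weights with FP32 master weights and FP16 optimizer moments.
    model = torch.nn.Linear(1024, 1024, dtype=torch.bfloat16, device="cuda")
    optimizer = FusedAdam(
        model.parameters(),
        lr=1e-3,
        master_weights=True,                # keep FP32 master copies in the optimizer
        master_weight_dtype=torch.float32,  # fp32/fp16 are accepted
        exp_avg_dtype=torch.float16,        # low-precision first moment, scaled internally
        exp_avg_sq_dtype=torch.float16,     # low-precision second moment, scaled internally
    )

    x = torch.randn(32, 1024, dtype=torch.bfloat16, device="cuda")
    loss = model(x).float().pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

Note that the low-precision optimizer states are only valid with `capturable=False` (the default), per the constraints above.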
+ if capturable and master_weights and master_weight_dtype != torch.float32: + raise RuntimeError("Capturable mode only supports fp32 master weights.") + if capturable and exp_avg_dtype != torch.float32: + raise RuntimeError("Capturable mode only supports fp32 exp_avg.") + if capturable and exp_avg_sq_dtype != torch.float32: + raise RuntimeError("Capturable mode only supports fp32 exp_avg_sq") + # If the optimizer is capturable then LR should be a tensor (on GPU) lr = torch.tensor(lr, dtype=torch.float32) if capturable else lr defaults = { @@ -112,9 +157,6 @@ def __init__( self.set_grad_none = set_grad_none self.capturable = capturable - - if master_weights is not None: - assert isinstance(master_weights, list), "master_weights must be a list if provided" self.master_weights = master_weights if capturable: @@ -134,14 +176,208 @@ def __init__( self.multi_tensor_adam_capturable = tex.multi_tensor_adam_capturable self.multi_tensor_adam_capturable_master = tex.multi_tensor_adam_capturable_master + self.master_weight_dtype = master_weight_dtype + self.exp_avg_dtype = exp_avg_dtype + self.exp_avg_sq_dtype = exp_avg_sq_dtype + self.name_to_dtype_map = { + "exp_avg": self.exp_avg_dtype, + "exp_avg_sq": self.exp_avg_sq_dtype, + "master_param": self.master_weight_dtype, + } + self.dtype_to_range_map = { + torch.float16: torch.full( + [1], torch.finfo(torch.float16).max / 2.0, dtype=torch.float32 + ), + torch.uint8: torch.full([1], 448.0, dtype=torch.float32), + } + self._scales = {} + self.use_decoupled_grad = use_decoupled_grad + def zero_grad(self): # pylint: disable=missing-function-docstring - if self.set_grad_none: - for group in self.param_groups: - for p in group["params"]: + if not self.use_decoupled_grad and not self.set_grad_none: + super().zero_grad() + return + + for group in self.param_groups: + for p in group["params"]: + if self.use_decoupled_grad and self.set_grad_none: + p.decoupled_grad = None + elif self.use_decoupled_grad and not self.set_grad_none: + p.decoupled_grad.zero_() + elif not self.use_decoupled_grad and self.set_grad_none: p.grad = None + + def _apply_scale(self, state_name, unscaled_state, scaled_state, scale): + """Apply scaling on `unscaled_state`. `scaled_state` and `scale` will be written inplace. + + Arguments: + state_name (string): Name of optimizer states, can be one of 'exp_avg', 'exp_avg_sq', + and 'master_param`. + unscaled_state (torch.Tensor): An unscaled high-precision tensor. + scaled_state (torch.Tensor): An scaled low-precision tensor. + scale (torch.Tensor): A FP32 tensor representing the scaling factor. 
+ """ + assert unscaled_state.dtype == torch.float32 + dtype = self.name_to_dtype_map[state_name] + if dtype == torch.uint8: + assert isinstance(scaled_state, Float8Tensor) else: - super().zero_grad() + assert scaled_state.dtype == dtype + + max_range = self.dtype_to_range_map[dtype] + if max_range.device != scaled_state.device: + max_range = max_range.to(scaled_state.device) + self.dtype_to_range_map[scaled_state.dtype] = max_range + if unscaled_state.device != scaled_state.device: + unscaled_state = unscaled_state.to(scaled_state.device) + min_val, max_val = torch.aminmax(unscaled_state) + absmax = torch.maximum(-min_val, max_val) + absmax = absmax.to(dtype=torch.float32, device=unscaled_state.device) + torch.div(absmax, max_range, out=scale) + if isinstance(scaled_state, Float8Tensor): + scaled_state._scale_inv.copy_(scale) + scaled_state.copy_(unscaled_state) + else: + rscale = torch.where(scale > 0, scale.reciprocal(), 0.0) + unscaled_state.mul_(rscale) + scaled_state.copy_(unscaled_state) + + def get_unscaled_state(self, param, state_name): + """Return the unscaled state corresponding to the input `param` and `state_name`. + + Arguments: + param (torch.nn.Parameter): One of parameters in this optimizer. + state_name (string): Name of optimizer states, can be one of 'exp_avg', 'exp_avg_sq', + and 'master_param`. + """ + state = self.state[param] + dtype = self.name_to_dtype_map[state_name] + if dtype == torch.uint8: + assert isinstance(state[state_name], Float8Tensor) + unscaled = state[state_name].float() + elif dtype == torch.float16: + assert state[state_name].dtype == torch.float16 + unscaled = state[state_name].float() + unscaled.mul_(self._scales[param][state_name]) + elif dtype == torch.float32: + assert state[state_name].dtype == torch.float32 + unscaled = state[state_name] + else: + raise RuntimeError(f"Dtype of {state_name} can only be fp8/fp16/fp32.") + return unscaled + + def set_scaled_state(self, param, state_name, unscaled_state): + """Set the optimizer state. + + If the dtype of the corresponding optimizer state is not FP32, + it will do scaling automatically. + + Arguments: + param (torch.nn.Parameter): One of parameters in this optimizer. + state_name (string): Name of optimizer states, can be one of 'exp_avg', 'exp_avg_sq', + and 'master_param`. + unscaled_state (torch.Tensor): The original high-precision(FP32) state. + """ + assert unscaled_state.dtype == torch.float32 + state = self.state[param] + if state_name not in state: + self._initialize_state(param, state_name, False) + + dtype = self.name_to_dtype_map[state_name] + if dtype != torch.float32: + scale = self._scales[param] + self._apply_scale(state_name, unscaled_state, state[state_name], scale[state_name]) + else: + state[state_name].copy_(unscaled_state) + + def _initialize_state(self, param, state_name, zero_buffer: bool): + """Initialize one of the optimizer states according to `state_name`. + + Arguments: + param (torch.nn.Parameter): One of parameters in this optimizer. + state_name (string): Name of optimizer states, can be one of 'exp_avg', 'exp_avg_sq', + and 'master_param`. + zero_buffer (bool): Whether to initialize the optimizer state with zeros. 
+ """ + dtype = self.name_to_dtype_map[state_name] + data = torch.empty_like(param, dtype=dtype) + if zero_buffer: + data.zero_() + + if dtype == torch.uint8: + self.state[param][state_name] = Float8Tensor( + data=data, + dtype=torch.float32, + fp8_scale_inv=torch.ones([1], dtype=torch.float32, device=param.device), + ) + else: + self.state[param][state_name] = data + + # Create scale if necessary. + if dtype != torch.float32: + if param not in self._scales: + self._scales[param] = {} + self._scales[param][state_name] = torch.ones( + [1], dtype=torch.float32, device=param.device + ) + + def initialize_state(self, param): + """Initialize optimizer states. + + Arguments: + param (torch.nn.Parameter): One of parameters in this optimizer. + """ + self._initialize_state(param, "exp_avg", zero_buffer=True) + self._initialize_state(param, "exp_avg_sq", zero_buffer=True) + if self.master_weights: + self._initialize_state(param, "master_param", zero_buffer=False) + self.set_scaled_state(param, "master_param", param.clone().detach().float()) + + def state_dict(self): + """Override the state_dict() of pytorch. Before returning the state_dict, cast all + non-fp32 states to fp32. + """ + state_dict = super().state_dict() + + groups = self.param_groups + saved_groups = deepcopy(state_dict["param_groups"]) + id_map = dict( + zip( + chain.from_iterable(g["params"] for g in saved_groups), + chain.from_iterable(g["params"] for g in groups), + ) + ) + for k, v in state_dict["state"].items(): + if k in id_map: + param = id_map[k] + new_v = {} + for name in v: + new_v[name] = self.get_unscaled_state(param, name) + state_dict["state"][k] = new_v + + return state_dict + + def load_state_dict(self, state_dict): + """Override the load_state_dict() of pytorch. Since pytorch's load_state_dict forces the + state to be the same dtype as param, We need to manully set the state again. + """ + super().load_state_dict(state_dict) + + groups = self.param_groups + saved_groups = deepcopy(state_dict["param_groups"]) + id_map = dict( + zip( + chain.from_iterable(g["params"] for g in saved_groups), + chain.from_iterable(g["params"] for g in groups), + ) + ) + for k, v in state_dict["state"].items(): + if k in id_map: + param = id_map[k] + self.state[param] = {} + for name in v: + self.set_scaled_state(param, name, v[name].float()) def step(self, closure=None, grad_scaler=None): """Performs a single optimization step. @@ -156,8 +392,6 @@ def step(self, closure=None, grad_scaler=None): if closure is not None: loss = closure() - master_param_idx = 0 - for group in self.param_groups: if len(group["params"]) == 0: continue @@ -196,6 +430,11 @@ def step(self, closure=None, grad_scaler=None): amaxes = [] scale_invs = [] + # Lists for scaling + unscaled_lists = {"exp_avg": [], "exp_avg_sq": [], "master_param": []} + scaled_lists = {"exp_avg": [], "exp_avg_sq": [], "master_param": []} + state_scales = {"exp_avg": [], "exp_avg_sq": [], "master_param": []} + # Only used when extra params include fp8 tensors. Otherwise, it doesn't matter what the out_dtype is. 
out_dtype = tex.DType.kFloat32 @@ -207,31 +446,29 @@ def step(self, closure=None, grad_scaler=None): # State initialization if len(state) == 0: - # Exponential moving average of gradient values - state["exp_avg"] = torch.zeros_like(p.data).float() - # Exponential moving average of squared gradient values - state["exp_avg_sq"] = torch.zeros_like(p.data).float() - # Master weights - if self.master_weights and p.dtype != torch.float32: - # model weights can be fp32/bf16/fp16/fp8 - # If it's fp32, it has no corresponding master weights - state["master_param"] = self.master_weights[master_param_idx] - master_param_idx += 1 - assert ( - state["master_param"].shape == p.shape - ), "Master weights shape must match model weights shape" - - p_master = state.get("master_param", None) - p_grad = p.grad - - if self.master_weights and p_master is not None and p_master.grad is not None: - p_grad = p_master.grad + self.initialize_state(p) + + if self.use_decoupled_grad: + p_grad = p.decoupled_grad if hasattr(p, "decoupled_grad") else None + else: + p_grad = p.grad if p_grad is None: continue if p_grad.data.is_sparse: raise RuntimeError("FusedAdam does not support sparse gradients.") + # Unscaling + unscaled_state = {} + for name in ["exp_avg", "exp_avg_sq", "master_param"]: + if name in state: + unscaled = self.get_unscaled_state(p, name) + unscaled_state[name] = unscaled + if self.name_to_dtype_map[name] != torch.float32: + unscaled_lists[name].append(unscaled) + scaled_lists[name].append(state[name]) + state_scales[name].append(self._scales[p][name]) + if isinstance(p, Float8Tensor): out_dtype = p._fp8_dtype p_fp8_model.append(p._data.data) @@ -240,26 +477,28 @@ def step(self, closure=None, grad_scaler=None): amaxes.append(amax) scale_invs.append(scale_inv) if self.master_weights: - p_main_of_fp8_model.append(p_master.data) + p_main_of_fp8_model.append(unscaled_state["master_param"].data) g_of_fp8_model.append(p_grad.data) - m_of_fp8_model.append(state["exp_avg"]) - v_of_fp8_model.append(state["exp_avg_sq"]) + m_of_fp8_model.append(unscaled_state["exp_avg"]) + v_of_fp8_model.append(unscaled_state["exp_avg_sq"]) elif p.dtype in [torch.float16, torch.bfloat16]: has_fp16 = has_fp16 or p.dtype == torch.float16 has_bf16 = has_bf16 or p.dtype == torch.bfloat16 p_f16_model.append(p.data) if self.master_weights: - p_main_of_f16_model.append(p_master.data) + p_main_of_f16_model.append(unscaled_state["master_param"].data) g_of_f16_model.append(p_grad.data) - m_of_f16_model.append(state["exp_avg"]) - v_of_f16_model.append(state["exp_avg_sq"]) + m_of_f16_model.append(unscaled_state["exp_avg"]) + v_of_f16_model.append(unscaled_state["exp_avg_sq"]) elif p.dtype == torch.float32: p_f32_model.append(p.data) g_of_f32_model.append(p_grad.data) - m_of_f32_model.append(state["exp_avg"]) - v_of_f32_model.append(state["exp_avg_sq"]) + m_of_f32_model.append(unscaled_state["exp_avg"]) + v_of_f32_model.append(unscaled_state["exp_avg_sq"]) else: - raise RuntimeError("FusedAdam only support model weights in fp16/bf16 and fp8") + raise RuntimeError( + "FusedAdam only support model weights in fp32, fp16, bf16 and fp8" + ) if self.capturable and len(p_fp8_model) > 0: raise RuntimeError( @@ -389,4 +628,15 @@ def apply_multi_tensor_adam(adam_func, tensor_lists, inv_scale=None, out_dtype=N tensor_lists = [g_of_f32_model, p_f32_model, m_of_f32_model, v_of_f32_model] apply_multi_tensor_adam(self.multi_tensor_adam, tensor_lists) + # Scaling + for name in ["exp_avg", "exp_avg_sq", "master_param"]: + if len(unscaled_lists[name]) > 0: + for 
unscaled, scaled, scale in zip( + unscaled_lists[name], scaled_lists[name], state_scales[name] + ): + self._apply_scale(name, unscaled, scaled, scale) + + # Try to reclaim the temporary fp32 buffers. + del unscaled_lists + return loss diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 110059d745..7ace68a222 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -74,30 +74,6 @@ def backward( return grad, None -def post_optimizer_step_fwd_amax_reduction(param: Float8Tensor) -> None: - """Amax scale and update when there is at least 1 trainable FP8 parameter.""" - param_id = id(param._data) - - if param_id not in FP8GlobalStateManager.fp8_param_to_autocast: - return - - autocast_key = FP8GlobalStateManager.fp8_param_to_autocast[param_id] - - if autocast_key not in FP8GlobalStateManager.autocast_to_fp8_params: - return - - if autocast_key in updated_fp8_params: - updated_fp8_params[autocast_key].add(param_id) - else: - updated_fp8_params[autocast_key] = {param_id} - - current_fp8_params_set = FP8GlobalStateManager.autocast_to_fp8_params[autocast_key] - # All FP8 trainable parameters have been updated. - if updated_fp8_params[autocast_key] == current_fp8_params_set: - FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=True, fp8_weights=True) - del updated_fp8_params[autocast_key] - - class _ToFloat8Func(torch.autograd.Function): """Cast to FP8 from other dtype""" @@ -109,10 +85,12 @@ def forward( fp8_meta_forward: bool = True, fp8_meta_index: Optional[int] = None, fp8_dtype: TE_DType = TE_DType.kFloat8E4M3, + data: Optional[torch.Tensor] = None, scale: Optional[torch.Tensor] = None, amax: Optional[torch.Tensor] = None, scale_inv: Optional[torch.Tensor] = None, with_transpose_cache: bool = False, + data_transpose: Optional[torch.Tensor] = None, ) -> Float8Tensor: # pylint: disable=missing-function-docstring @@ -125,7 +103,8 @@ def forward( device = torch.device("cuda") # FP8 data buffer - data = torch.empty(tensor.size(), dtype=torch.uint8, device=device) + if data is None: + data = torch.empty(tensor.size(), dtype=torch.uint8, device=device) # Check scale if scale is None and fp8_meta is None: @@ -140,8 +119,7 @@ def forward( scale_inv = scale_inv.to(device=device, dtype=torch.float32) # Transpose cache - data_transpose = None - if with_transpose_cache: + if data_transpose is None and with_transpose_cache: data_transpose = torch.empty( (data.size(-1), data.numel() // data.size(-1)), dtype=torch.uint8, @@ -172,7 +150,7 @@ def backward( ) -> Tuple[Optional[torch.Tensor], ...]: # pylint: disable=missing-function-docstring # Assume that we want gradients in full precision - return grad, None, None, None, None, None, None, None + return grad, None, None, None, None, None, None, None, None, None class _IdentityFunc(torch.autograd.Function): @@ -674,9 +652,6 @@ def quantize_( ) dst._transpose_invalid = False - # Callback hook to perform amax reduction after optimizer step - post_optimizer_step_fwd_amax_reduction(self) - return self @classmethod @@ -688,10 +663,12 @@ def to_float8( fp8_meta_forward: bool = True, fp8_meta_index: Optional[int] = None, fp8_dtype: TE_DType = TE_DType.kFloat8E4M3, + data: Optional[torch.Tensor] = None, scale: Optional[torch.Tensor] = None, amax: Optional[torch.Tensor] = None, scale_inv: Optional[torch.Tensor] = None, with_transpose_cache: bool = False, + data_transpose: Optional[torch.Tensor] = None, ): """Construct 
Float8Tensor from plain PyTorch tensor""" return _ToFloat8Func.apply( @@ -700,10 +677,12 @@ def to_float8( fp8_meta_forward, fp8_meta_index, fp8_dtype, + data, scale, amax, scale_inv, with_transpose_cache, + data_transpose, ) def detach(self) -> Float8Tensor:
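To illustrate the new `data` and `data_transpose` arguments of `Float8Tensor.to_float8`, here is a hedged sketch that casts into a preallocated FP8 buffer. It assumes Transformer Engine is installed, a CUDA device is available, and the default scaling factor is acceptable:

    import torch
    from transformer_engine.pytorch.float8_tensor import Float8Tensor

    x = torch.randn(128, 256, device="cuda")

    # Preallocated uint8 storage, e.g. a Userbuffers workspace registered elsewhere.
    data = torch.empty(128, 256, dtype=torch.uint8, device="cuda")

    # Cast into the provided buffer instead of allocating a fresh one.
    x_fp8 = Float8Tensor.to_float8(x, data=data)
    print(x_fp8._data.data_ptr() == data.data_ptr())  # expected: True

This is the mechanism the fused Userbuffers linear op above relies on when casting activations directly into communication buffers.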