diff --git a/.ci/tritonbench/install-triton-nightly.sh b/.ci/tritonbench/install-triton-nightly.sh
new file mode 100644
index 0000000000..4d79004f3b
--- /dev/null
+++ b/.ci/tritonbench/install-triton-nightly.sh
@@ -0,0 +1,26 @@
+#!/bin/bash
+if [ -z "${BASE_CONDA_ENV}" ]; then
+  echo "ERROR: BASE_CONDA_ENV is not set"
+  exit 1
+fi
+
+if [ -z "${CONDA_ENV}" ]; then
+  echo "ERROR: CONDA_ENV is not set"
+  exit 1
+fi
+
+if [ -z "${SETUP_SCRIPT}" ]; then
+  echo "ERROR: SETUP_SCRIPT is not set"
+  exit 1
+fi
+
+CONDA_ENV=${BASE_CONDA_ENV} . "${SETUP_SCRIPT}"
+conda activate "${BASE_CONDA_ENV}"
+# Remove the conda env if it exists
+conda remove --name "${CONDA_ENV}" -y --all || true
+conda create --name "${CONDA_ENV}" -y --clone "${BASE_CONDA_ENV}"
+conda activate "${CONDA_ENV}"
+
+. "${SETUP_SCRIPT}"
+# Install the openai/triton nightly
+pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly
diff --git a/.ci/tritonbench/test.sh b/.ci/tritonbench/test-install.sh
similarity index 87%
rename from .ci/tritonbench/test.sh
rename to .ci/tritonbench/test-install.sh
index 34604aae42..383f7d4cda 100644
--- a/.ci/tritonbench/test.sh
+++ b/.ci/tritonbench/test-install.sh
@@ -8,5 +8,5 @@ fi
 parent_dir=$(dirname "$(readlink -f "$0")")/../..
 cd ${parent_dir}
 
-# Test TritonBench
+# Test TritonBench installation
 python install.py --userbenchmark triton --fbgemm --test
diff --git a/.ci/tritonbench/test-operators.sh b/.ci/tritonbench/test-operators.sh
new file mode 100644
index 0000000000..40af2f18f5
--- /dev/null
+++ b/.ci/tritonbench/test-operators.sh
@@ -0,0 +1,28 @@
+#!/bin/bash
+set -x
+
+if [ -z "${SETUP_SCRIPT}" ]; then
+  echo "ERROR: SETUP_SCRIPT is not set"
+  exit 1
+fi
+
+. "${SETUP_SCRIPT}"
+
+# Test Tritonbench operators
+# TODO: test every operator, fwd+bwd
+python run_benchmark.py triton --op launch_latency --mode fwd --num-inputs 1 --test-only
+python run_benchmark.py triton --op addmm --mode fwd --num-inputs 1 --test-only
+python run_benchmark.py triton --op gemm --mode fwd --num-inputs 1 --test-only
+python run_benchmark.py triton --op sum --mode fwd --num-inputs 1 --test-only
+python run_benchmark.py triton --op softmax --mode fwd --num-inputs 1 --test-only
+python run_benchmark.py triton --op layer_norm --mode fwd --num-inputs 1 --test-only
+
+
+# Segfault
+# python run_benchmark.py triton --op flash_attention --mode fwd --num-inputs 1 --test-only
+
+# CUDA OOM
+# python run_benchmark.py triton --op jagged_layer_norm --mode fwd --num-inputs 1 --test-only
+# python run_benchmark.py triton --op jagged_mean --mode fwd --num-inputs 1 --test-only
+# python run_benchmark.py triton --op jagged_softmax --mode fwd --num-inputs 1 --test-only
+# python run_benchmark.py triton --op jagged_sum --mode fwd --num-inputs 1 --test-only
diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml
index 00417d454e..815506b7cb 100644
--- a/.github/workflows/pr-test.yml
+++ b/.github/workflows/pr-test.yml
@@ -1,10 +1,24 @@
 name: TorchBench PR Test
 on:
   pull_request:
+    # ignore tritonbench paths
+    paths-ignore:
+      - 'torchbenchmark/operators/*'
+      - 'torchbenchmark/util/kernels/*'
+      - 'torchbenchmark/util/triton_op.py'
+      - 'userbenchmark/triton/*'
+      - '.ci/tritonbench/*'
   workflow_dispatch:
   push:
     branches:
       - main
+    # ignore tritonbench paths
+    paths-ignore:
+      - 'torchbenchmark/operators/*'
+      - 'torchbenchmark/util/kernels/*'
+      - 'torchbenchmark/util/triton_op.py'
+      - 'userbenchmark/triton/*'
+      - '.ci/tritonbench/*'
 
 jobs:
   cpu-test:
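For context on what install-triton-nightly.sh actually verifies: the step only proves that pip can resolve and install the triton-nightly wheel into the cloned env. A minimal smoke test that the nightly compiler really works (a hypothetical helper, not part of this patch) could run right after it:

    # smoke_triton.py -- assumed helper: compile and launch one trivial kernel.
    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def add_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr):
        offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
        mask = offs < n
        x = tl.load(x_ptr + offs, mask=mask)
        y = tl.load(y_ptr + offs, mask=mask)
        tl.store(out_ptr + offs, x + y, mask=mask)

    n = 4096
    x = torch.randn(n, device="cuda")
    y = torch.randn(n, device="cuda")
    out = torch.empty_like(x)
    add_kernel[(triton.cdiv(n, 256),)](x, y, out, n, BLOCK=256)
    assert torch.allclose(out, x + y)
    print(f"triton {triton.__version__}: OK")

In practice the `--test-only` operator runs in test-operators.sh exercise the same path, so a check like this would only be a faster first signal.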
diff --git a/.github/workflows/tritonbench-test.yml b/.github/workflows/tritonbench-test.yml
new file mode 100644
index 0000000000..7e58ccf39f
--- /dev/null
+++ b/.github/workflows/tritonbench-test.yml
@@ -0,0 +1,63 @@
+name: Tritonbench PR Test on Triton nightly
+on:
+  pull_request:
+    paths:
+      - 'torchbenchmark/operators/*'
+      - 'torchbenchmark/util/kernels/*'
+      - 'torchbenchmark/util/triton_op.py'
+      - 'userbenchmark/triton/*'
+      - '.ci/tritonbench/*'
+  workflow_dispatch:
+  push:
+    branches:
+      - main
+    paths:
+      - 'torchbenchmark/operators/*'
+      - 'torchbenchmark/util/kernels/*'
+      - 'torchbenchmark/util/triton_op.py'
+      - 'userbenchmark/triton/*'
+      - '.ci/tritonbench/*'
+
+jobs:
+  cuda-test:
+    # Don't run on forked repos
+    if: github.repository_owner == 'pytorch'
+    runs-on: [a100-runner]
+    timeout-minutes: 240
+    environment: docker-s3-upload
+    env:
+      BASE_CONDA_ENV: "torchbench"
+      CONDA_ENV: "tritonbench-pr-test-cuda"
+      SETUP_SCRIPT: "/workspace/setup_instance.sh"
+      TEST_CONFIG: "cuda"
+      HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
+    steps:
+      - name: Checkout TorchBench
+        uses: actions/checkout@v3
+        with:
+          submodules: 'true'
+      - name: Tune Nvidia GPU
+        run: |
+          sudo nvidia-smi -pm 1
+          sudo nvidia-smi -ac 1215,1410
+          sudo ldconfig
+          nvidia-smi
+      - name: Install triton-nightly
+        run: |
+          bash ./.ci/tritonbench/install-triton-nightly.sh
+      - name: Test Tritonbench install
+        run: |
+          bash ./.ci/tritonbench/test-install.sh
+      - name: Test Tritonbench operators
+        run: |
+          bash ./.ci/tritonbench/test-operators.sh
+      - name: Clean up Conda env
+        if: always()
+        run: |
+          . "${SETUP_SCRIPT}"
+          conda deactivate && conda deactivate
+          conda remove -n "${CONDA_ENV}" --all
+
+concurrency:
+  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
+  cancel-in-progress: true
diff --git a/.github/workflows/userbenchmark-a100-release.yml b/.github/workflows/userbenchmark-a100-release.yml
new file mode 100644
index 0000000000..6523af9691
--- /dev/null
+++ b/.github/workflows/userbenchmark-a100-release.yml
@@ -0,0 +1,60 @@
+name: Release TorchBench Userbenchmark on A100
+on:
+  pull_request:
+    paths:
+      - userbenchmark/release-test/*
+
+jobs:
+  run-userbenchmark:
+    runs-on: [a100-runner]
+    timeout-minutes: 1440 # 24 hours
+    environment: docker-s3-upload
+    env:
+      BASE_CONDA_ENV: "torchbench"
+      CONDA_ENV: "userbenchmark-a100"
+      PLATFORM_NAME: "gcp_a100"
+      SETUP_SCRIPT: "/workspace/setup_instance.sh"
+    steps:
+      - name: Checkout TorchBench
+        uses: actions/checkout@v3
+        with:
+          path: benchmark
+          ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
+      - name: Tune Nvidia GPU
+        run: |
+          sudo nvidia-smi -pm 1
+          sudo nvidia-smi -ac 1215,1410
+          nvidia-smi
+      - name: Clone and setup conda env
+        run: |
+          CONDA_ENV=${BASE_CONDA_ENV} . "${SETUP_SCRIPT}"
+          conda create --name "${CONDA_ENV}" --clone "${BASE_CONDA_ENV}"
+      - name: Install TorchBench
+        run: |
+          set -x
+          . "${SETUP_SCRIPT}"
+          pushd benchmark
+          python install.py
+      - name: Run user benchmark
+        run: |
+          set -x
+          . "${SETUP_SCRIPT}"
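A note on the trigger filters: the `paths` list in tritonbench-test.yml is meant to be the exact complement of the `paths-ignore` list added to pr-test.yml, so every PR lands in exactly one of the two workflows. Nothing enforces that invariant; a small consistency check (hypothetical script, not in this patch) could catch drift:

    # check_path_filters.py -- assumed helper, not part of this patch.
    import yaml

    def pr_paths(workflow, key):
        with open(workflow) as f:
            wf = yaml.safe_load(f)
        # PyYAML (YAML 1.1) parses the bare `on:` key as boolean True.
        triggers = wf.get("on") or wf[True]
        return set(triggers["pull_request"][key])

    ignored = pr_paths(".github/workflows/pr-test.yml", "paths-ignore")
    tested = pr_paths(".github/workflows/tritonbench-test.yml", "paths")
    assert ignored == tested, f"path filters out of sync: {ignored ^ tested}"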
"${SETUP_SCRIPT}" + # remove old results + if [ -d benchmark-output ]; then rm -Rf benchmark-output; fi + pushd benchmark + release_version=$(cat userbenchmark/release-test/version.txt) + if [ -d .userbenchmark ]; then rm -Rf .userbenchmark; fi + python run_benchmark.py release-test -c ${release_version} + cp -r ./.userbenchmark/release-test ../benchmark-output + + - name: Upload artifact + uses: actions/upload-artifact@v3 + with: + name: TorchBench result + path: benchmark-output/ + - name: Clean up Conda env + if: always() + run: | + . "${SETUP_SCRIPT}" + conda deactivate && conda deactivate + conda remove -n "${CONDA_ENV}" --all diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000..da571fcfd0 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,12 @@ +[build-system] +# Use legacy backend to import local packages in setup.py +build-backend = "setuptools.build_meta:__legacy__" + + +[tool.black] +line-length = 88 +target-version = ["py38"] +exclude = '''/submodules/.*''' + +[tool.usort] +excludes = ["**/submodules/**"] diff --git a/requirements.txt b/requirements.txt index b9d52c5b24..aca426bb20 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,8 +10,9 @@ pytest-benchmark requests tabulate git+https://github.com/huggingface/pytorch-image-models.git@730b907 -# this version of transformers is required as per this page https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1 -transformers==4.38.1 +# this version of transformers is required by linger-kernel +# https://github.com/linkedin/Liger-Kernel/blob/main/pyproject.toml#L23 +transformers==4.44.2 MonkeyType psutil pyyaml diff --git a/torchbenchmark/operators/FusedLinearCrossEntropy/__init__.py b/torchbenchmark/operators/FusedLinearCrossEntropy/__init__.py new file mode 100644 index 0000000000..a77a295cc4 --- /dev/null +++ b/torchbenchmark/operators/FusedLinearCrossEntropy/__init__.py @@ -0,0 +1 @@ +from .operator import Operator diff --git a/torchbenchmark/operators/FusedLinearCrossEntropy/operator.py b/torchbenchmark/operators/FusedLinearCrossEntropy/operator.py new file mode 100644 index 0000000000..9b5ed35541 --- /dev/null +++ b/torchbenchmark/operators/FusedLinearCrossEntropy/operator.py @@ -0,0 +1,108 @@ +import argparse +from typing import Callable, Generator, List, Optional + +import torch + +from torchbenchmark.util.triton_op import BenchmarkOperator, register_benchmark + +try: + from liger_kernel.transformers.fused_linear_cross_entropy import ( + LigerFusedLinearCrossEntropyLoss, + ) +except ModuleNotFoundError: + LigerFusedLinearCrossEntropyLoss = None + +# Reference: https://github.com/linkedin/Liger-Kernel/blob/\ +# 3d0653b035222cbb845435a1994854e4fd219107/benchmark/scripts/benchmark_fused_linear_cross_entropy.py + + +def parse_op_args(args: List[str]): + parser = argparse.ArgumentParser() + parser.add_argument("--hidden-size", type=int, default=4096, help="hidden size") + parser.add_argument("--vocab-size", type=int, default=128256, help="vocab size") + return parser.parse_args(args) + + +class TorchLMHeadCE(torch.nn.Module): + """Ground truth implementation of the linear fused with torch based cross entropy loss. 
diff --git a/torchbenchmark/operators/FusedLinearCrossEntropy/__init__.py b/torchbenchmark/operators/FusedLinearCrossEntropy/__init__.py
new file mode 100644
index 0000000000..a77a295cc4
--- /dev/null
+++ b/torchbenchmark/operators/FusedLinearCrossEntropy/__init__.py
@@ -0,0 +1 @@
+from .operator import Operator
diff --git a/torchbenchmark/operators/FusedLinearCrossEntropy/operator.py b/torchbenchmark/operators/FusedLinearCrossEntropy/operator.py
new file mode 100644
index 0000000000..9b5ed35541
--- /dev/null
+++ b/torchbenchmark/operators/FusedLinearCrossEntropy/operator.py
@@ -0,0 +1,108 @@
+import argparse
+from typing import Callable, Generator, List, Optional
+
+import torch
+
+from torchbenchmark.util.triton_op import BenchmarkOperator, register_benchmark
+
+try:
+    from liger_kernel.transformers.fused_linear_cross_entropy import (
+        LigerFusedLinearCrossEntropyLoss,
+    )
+except ModuleNotFoundError:
+    LigerFusedLinearCrossEntropyLoss = None
+
+# Reference: https://github.com/linkedin/Liger-Kernel/blob/\
+# 3d0653b035222cbb845435a1994854e4fd219107/benchmark/scripts/benchmark_fused_linear_cross_entropy.py
+
+
+def parse_op_args(args: List[str]):
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--hidden-size", type=int, default=4096, help="hidden size")
+    parser.add_argument("--vocab-size", type=int, default=128256, help="vocab size")
+    return parser.parse_args(args)
+
+
+class TorchLMHeadCE(torch.nn.Module):
+    """Ground truth implementation: a linear layer followed by torch's cross entropy loss.
+
+    :param H: hidden size
+    :param V: vocab size
+    :param ignore_index: index to ignore
+    :param reduction: reduction method
+    """
+
+    def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
+        super().__init__()
+        self.lin = torch.nn.Linear(
+            in_features=H, out_features=V, bias=False, dtype=dtype
+        )
+        self.ce_loss = torch.nn.CrossEntropyLoss(
+            ignore_index=ignore_index, reduction="mean"
+        )
+
+    def forward(self, input, target):
+        logits = self.lin(input)
+        return self.ce_loss(logits, target)
+
+
+class LigerLMHeadCE(torch.nn.Module):
+    def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
+        super().__init__()
+        self.lin = torch.nn.Linear(
+            in_features=H, out_features=V, bias=False, dtype=dtype
+        )
+        self.ce_loss = LigerFusedLinearCrossEntropyLoss(
+            ignore_index=ignore_index, reduction="mean"
+        )
+
+    def forward(self, input, target):
+        return self.ce_loss(self.lin.weight, input, target)
+
+
+class Operator(BenchmarkOperator):
+    def __init__(
+        self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
+    ):
+        super().__init__(tb_args, extra_args)
+        op_args = parse_op_args(self.extra_args)
+        self.hidden_size = op_args.hidden_size
+        self.vocab_size = op_args.vocab_size
+        self.baseline_model = TorchLMHeadCE(
+            H=self.hidden_size, V=self.vocab_size, dtype=self.dtype
+        ).to(self.device)
+        self.liger_model = LigerLMHeadCE(
+            H=self.hidden_size, V=self.vocab_size, dtype=self.dtype
+        ).to(self.device)
+        self.use_cuda_graphs = False
+
+    def get_input_iter(self) -> Generator:
+        for BT in [2**i for i in range(12, 16)]:
+            _input = torch.randn(
+                BT,
+                self.hidden_size,
+                requires_grad=True,
+                dtype=self.dtype,
+                device=self.device,
+            )
+            target = torch.randint(
+                self.vocab_size, (BT, 1), dtype=torch.long, device=self.device
+            ).squeeze(1)
+            yield _input, target
+
+    @register_benchmark(baseline=True)
+    def LMHeadCE(self, input, target) -> Callable:
+        return lambda: self.baseline_model(input, target)
+
+    @register_benchmark()
+    def LigerLMHeadCE(self, input, target) -> Callable:
+        return lambda: self.liger_model(input, target)
+
+    @register_benchmark()
+    def inductor_fused_linear_cross_entropy(self, input, target) -> Callable:
+        compiled = torch.compile(self.baseline_model, dynamic=False)
+        return lambda: compiled(input, target)
+
+    def get_bwd_fn(self, fwd_fn: Callable) -> Callable:
+        y = fwd_fn()
+        return lambda: y.backward(retain_graph=True)
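Why this operator is worth benchmarking at these shapes: with the defaults above (V=128256, inputs up to BT=2**15), the unfused baseline materializes the full [BT, V] logits tensor before computing the loss, which is exactly the allocation LigerFusedLinearCrossEntropyLoss is designed to avoid. Back-of-envelope arithmetic, assuming a 2-byte dtype such as bf16:

    # Logits footprint for the largest input get_input_iter() yields.
    BT, V = 2**15, 128256
    bytes_per_elem = 2  # bf16/fp16 assumption
    gib = BT * V * bytes_per_elem / 2**30
    print(f"logits tensor alone: {gib:.1f} GiB")  # ~7.8 GiB, before backward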
diff --git a/torchbenchmark/operators/fp8_gemm_rowwise/operator.py b/torchbenchmark/operators/fp8_gemm_rowwise/operator.py
index b906980280..3db513b76b 100644
--- a/torchbenchmark/operators/fp8_gemm_rowwise/operator.py
+++ b/torchbenchmark/operators/fp8_gemm_rowwise/operator.py
@@ -27,12 +27,18 @@ def parse_args(args: List[str]) -> argparse.Namespace:
         "--no_fp8_fast_accum", dest="fp8_fast_accum", action="store_false"
     )
     parser.add_argument("--no_use_tma", dest="use_tma", action="store_false")
-    args = parser.parse_args(args)
-    return args
+    parser.add_argument(
+        "--no_use_persistent",
+        dest="no_use_persistent",
+        action="store_true",
+    )
+    parsed_args = parser.parse_args(args)
+    return parsed_args
 
 
 try:
     from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import (
+        get_fp8_constants as get_fp8_constants,
         matmul_fp8_row as triton_fp8_row,
     )
@@ -52,7 +58,7 @@ def parse_args(args: List[str]) -> argparse.Namespace:
     from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import scale_fp8_row
 
     HAS_CUBLAS = True
-except ImportError:
+except (ImportError, IOError, AttributeError):
     HAS_CUBLAS = False
@@ -79,7 +85,8 @@ def parse_args(args: List[str]) -> argparse.Namespace:
     (16384, 8192, 13312),
 ]
 
-E4M3_MAX_POS: float = torch.finfo(torch.float8_e4m3fn).max
+FP8_DTYPE, _, _, _ = get_fp8_constants()
+E4M3_MAX_POS: float = torch.finfo(FP8_DTYPE).max
 EPS: float = 1e-12
 FP16_MAX_POS: float = torch.finfo(torch.float16).max
@@ -91,7 +98,7 @@ def fp8_row_quantize(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     if x.dtype is torch.float16:
         scale = torch.clamp(scale, max=FP16_MAX_POS)
     xq = torch.clamp(x * scale[:, None], min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS).to(
-        torch.float8_e4m3fn
+        FP8_DTYPE
     )
     return xq, scale.reciprocal().to(torch.float32)
@@ -113,6 +120,7 @@ def __init__(
         self.shapes = BUILDIN_SHAPES
         self.fp8_fast_accum = addmm_args.fp8_fast_accum
         self.use_tma = addmm_args.use_tma
+        self.no_use_persistent = addmm_args.no_use_persistent
 
     @register_benchmark(enabled=HAS_TRITON, baseline=True)
     def _triton(self, xq, wq, x_scale, w_scale) -> Callable:
@@ -123,6 +131,7 @@ def _triton(self, xq, wq, x_scale, w_scale) -> Callable:
             w_scale,
             fp8_fast_accum=self.fp8_fast_accum,
             tma_persistent=self.use_tma,
+            no_use_persistent=self.no_use_persistent,
         )
 
     @register_benchmark(enabled=HAS_CUTLASS)
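For readers new to the scheme fp8_row_quantize implements: each row is scaled by its own absolute max, so a single outlier row cannot destroy the precision of every other row. Note also that the patch takes the fp8 dtype from fbgemm's get_fp8_constants() instead of hardcoding torch.float8_e4m3fn, presumably so platforms with a different fp8 variant pick up the right dtype. A self-contained sketch of the row-wise idea in plain torch (illustration only, assuming a PyTorch build with torch.float8_e4m3fn):

    import torch

    x = torch.randn(4, 8, dtype=torch.float16)
    E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
    # One scale per row, derived from that row's absolute max.
    scale = E4M3_MAX / x.abs().amax(dim=1).float().clamp(min=1e-12)
    xq = (x * scale[:, None]).clamp(-E4M3_MAX, E4M3_MAX).to(torch.float8_e4m3fn)
    x_rec = xq.to(torch.float16) * scale.reciprocal()[:, None]
    print((x - x_rec).abs().max().item())  # error stays bounded per row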
diff --git a/userbenchmark/dynamo/dynamobench/_dynamo/utils.py b/userbenchmark/dynamo/dynamobench/_dynamo/utils.py
index 5257f0cce9..7205fdaaf4 100644
--- a/userbenchmark/dynamo/dynamobench/_dynamo/utils.py
+++ b/userbenchmark/dynamo/dynamobench/_dynamo/utils.py
@@ -805,6 +805,8 @@ class CompilationMetrics:
     remote_cache_time_saved_s: Optional[float]
     structured_logging_overhead_s: Optional[float]
     config_suppress_errors: Optional[bool]
+    config_inline_inbuilt_nn_modules: Optional[bool]
+    specialize_float: Optional[bool]
 
 
 @dataclasses.dataclass
diff --git a/userbenchmark/release-test/run_release_test.sh b/userbenchmark/release-test/run_release_test.sh
index 8eb7112ac5..6272a45249 100644
--- a/userbenchmark/release-test/run_release_test.sh
+++ b/userbenchmark/release-test/run_release_test.sh
@@ -16,12 +16,11 @@ fi
 
 . switch-cuda.sh "${CUDA_VERSION}"
 
-if [[ ${CUDA_VERSION} == "12.1" ]]; then
-    pip install nvidia-cuda-nvcc-cu12
-fi
 nvcc --version
 sudo apt-get install bc
+sudo apt-get install --reinstall time
+which time
 # run mnist
 mkdir -p "${RESULT_DIR}/mnist"
 pushd "${EXAMPLES_DIR}/mnist"
diff --git a/userbenchmark/release-test/version.txt b/userbenchmark/release-test/version.txt
new file mode 100644
index 0000000000..437459cd94
--- /dev/null
+++ b/userbenchmark/release-test/version.txt
@@ -0,0 +1 @@
+2.5.0
diff --git a/userbenchmark/triton/install.py b/userbenchmark/triton/install.py
index a762c402ef..0f5e0ea82d 100644
--- a/userbenchmark/triton/install.py
+++ b/userbenchmark/triton/install.py
@@ -66,6 +66,13 @@ def install_fa3():
     subprocess.check_call(cmd, cwd=str(FA3_PATH.resolve()))
 
 
+def install_liger():
+    # Liger-Kernel pins a `triton` dependency that conflicts with pytorch's,
+    # so we need to install it without dependencies
+    cmd = ["pip", "install", "liger-kernel", "--no-deps"]
+    subprocess.check_call(cmd)
+
+
 def install_tk():
     try:
         from .tk.install import install_tk
@@ -88,6 +95,7 @@ def install_tk():
     )
     parser.add_argument("--jax", action="store_true", help="Install jax nightly")
     parser.add_argument("--tk", action="store_true", help="Install ThunderKittens")
+    parser.add_argument("--liger", action="store_true", help="Install Liger-Kernel")
    parser.add_argument("--test", action="store_true", help="Run test")
     args = parser.parse_args()
@@ -105,3 +113,5 @@ def install_tk():
         install_jax()
     if args.tk and not args.test:
         install_tk()
+    if args.liger and not args.test:
+        install_liger()
diff --git a/userbenchmark/triton/run.py b/userbenchmark/triton/run.py
index f8ab0821aa..ad2c58c2aa 100644
--- a/userbenchmark/triton/run.py
+++ b/userbenchmark/triton/run.py
@@ -29,7 +29,12 @@ def get_parser(args=None):
     parser = argparse.ArgumentParser(allow_abbrev=False)
-    parser.add_argument("--op", type=str, required=False, help="Operator to benchmark.")
+    parser.add_argument(
+        "--op",
+        type=str,
+        required=False,
+        help="Operators to benchmark; separate multiple operators with commas.",
+    )
     parser.add_argument(
         "--mode",
         choices=["fwd", "bwd", "fwd_bwd", "fwd_no_grad"],
@@ -188,5 +193,11 @@ def run(args: List[str] = []):
         run_ci()
         return
 
+    if args.op:
+        ops = args.op.split(",")
+    else:
+        ops = []
     with gpu_lockdown(args.gpu_lockdown):
-        _run(args, extra_args)
+        for op in ops:
+            args.op = op
+            _run(args, extra_args)
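With the comma syntax, a single invocation such as `python run_benchmark.py triton --op addmm,gemm --num-inputs 1 --test-only` now benchmarks several operators inside one gpu_lockdown session. One behavioral wrinkle worth noting: when --op is omitted (and --ci is not set), ops is empty and the loop body never runs, whereas the old code called _run() once unconditionally. A sketch of a fallback that would preserve the old single-run behavior (an assumption about intent, not code from this patch):

    # Hypothetical fallback keeping the pre-patch behavior for a bare invocation.
    ops = args.op.split(",") if args.op else [None]
    with gpu_lockdown(args.gpu_lockdown):
        for op in ops:
            if op is not None:
                args.op = op
            _run(args, extra_args)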