Commit: Merge branch 'pytorch:main' into main

juliagmt-google authored Oct 8, 2024
2 parents e309f1b + 1ac701f commit dca8110
Showing 16 changed files with 358 additions and 13 deletions.
26 changes: 26 additions & 0 deletions .ci/tritonbench/install-triton-nightly.sh
@@ -0,0 +1,26 @@
#!/bin/bash
if [ -z "${BASE_CONDA_ENV}" ]; then
echo "ERROR: BASE_CONDA_ENV is not set"
exit 1
fi

if [ -z "${CONDA_ENV}" ]; then
echo "ERROR: CONDA_ENV is not set"
exit 1
fi

if [ -z "${SETUP_SCRIPT}" ]; then
echo "ERROR: SETUP_SCRIPT is not set"
exit 1
fi

CONDA_ENV=${BASE_CONDA_ENV} . "${SETUP_SCRIPT}"
conda activate "${BASE_CONDA_ENV}"
# Remove the conda env if it exists
conda remove --name "${CONDA_ENV}" -y --all || true
conda create --name "${CONDA_ENV}" -y --clone "${BASE_CONDA_ENV}"
conda activate "${CONDA_ENV}"

. "${SETUP_SCRIPT}"
# Install the nightly openai/triton
pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly
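
In isolation, this script clones a fresh conda env from a base env and installs the Triton nightly wheel into it. A minimal local invocation might look like the sketch below; the env names and setup-script path are the values used by the tritonbench-test.yml workflow later in this commit, shown here only as placeholders:

# Hypothetical local run of the installer; exported values mirror tritonbench-test.yml
export BASE_CONDA_ENV="torchbench"
export CONDA_ENV="tritonbench-pr-test-cuda"
export SETUP_SCRIPT="/workspace/setup_instance.sh"
bash ./.ci/tritonbench/install-triton-nightly.sh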
2 changes: 1 addition & 1 deletion .ci/tritonbench/test-install.sh
@@ -8,5 +8,5 @@ fi
parent_dir=$(dirname "$(readlink -f "$0")")/../..
cd ${parent_dir}

-# Test TritonBench
+# Test TritonBench installation
python install.py --userbenchmark triton --fbgemm --test
28 changes: 28 additions & 0 deletions .ci/tritonbench/test-operators.sh
@@ -0,0 +1,28 @@
#!/bin/bash
set -x

if [ -z "${SETUP_SCRIPT}" ]; then
echo "ERROR: SETUP_SCRIPT is not set"
exit 1
fi

. "${SETUP_SCRIPT}"

# Test Tritonbench operators
# TODO: test every operator, fwd+bwd
python run_benchmark.py triton --op launch_latency --mode fwd --num-inputs 1 --test-only
python run_benchmark.py triton --op addmm --mode fwd --num-inputs 1 --test-only
python run_benchmark.py triton --op gemm --mode fwd --num-inputs 1 --test-only
python run_benchmark.py triton --op sum --mode fwd --num-inputs 1 --test-only
python run_benchmark.py triton --op softmax --mode fwd --num-inputs 1 --test-only
python run_benchmark.py triton --op layer_norm --mode fwd --num-inputs 1 --test-only


# Segfault
# python run_benchmark.py triton --op flash_attention --mode fwd --num-inputs 1 --test-only

# CUDA OOM
# python run_benchmark.py triton --op jagged_layer_norm --mode fwd --num-inputs 1 --test-only
# python run_benchmark.py triton --op jagged_mean --mode fwd --num-inputs 1 --test-only
# python run_benchmark.py triton --op jagged_softmax --mode fwd --num-inputs 1 --test-only
# python run_benchmark.py triton --op jagged_sum --mode fwd --num-inputs 1 --test-only
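
The TODO above calls for covering every operator in both directions. A hedged sketch of that loop, assuming --mode bwd is accepted the same way --mode fwd is (not verified in this commit):

# Sketch only: iterate the operators known to pass, in both directions.
# Assumes "--mode bwd" exists analogously to "--mode fwd".
for op in launch_latency addmm gemm sum softmax layer_norm; do
  for mode in fwd bwd; do
    python run_benchmark.py triton --op "${op}" --mode "${mode}" --num-inputs 1 --test-only
  done
done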
14 changes: 14 additions & 0 deletions .github/workflows/pr-test.yml
@@ -1,10 +1,24 @@
name: TorchBench PR Test
on:
pull_request:
# ignore tritonbench paths
paths-ignore:
- 'torchbenchmark/operators/*'
- 'torchbenchmark/util/kernels/*'
- 'torchbenchmark/util/triton_op.py'
- 'userbenchmark/triton/*'
- '.ci/tritonbench/*'
workflow_dispatch:
push:
branches:
- main
# ignore tritonbench paths
paths-ignore:
- 'torchbenchmark/operators/*'
- 'torchbenchmark/util/kernels/*'
- 'torchbenchmark/util/triton_op.py'
- 'userbenchmark/triton/*'
- '.ci/tritonbench/*'

jobs:
cpu-test:
63 changes: 63 additions & 0 deletions .github/workflows/tritonbench-test.yml
@@ -0,0 +1,63 @@
name: Tritonbench PR Test on Triton nightly
on:
pull_request:
paths:
- 'torchbenchmark/operators/*'
- 'torchbenchmark/util/kernels/*'
- 'torchbenchmark/util/triton_op.py'
- 'userbenchmark/triton/*'
- '.ci/tritonbench/*'
workflow_dispatch:
push:
branches:
- main
paths:
- 'torchbenchmark/operators/*'
- 'torchbenchmark/util/kernels/*'
- 'torchbenchmark/util/triton_op.py'
- 'userbenchmark/triton/*'
- '.ci/tritonbench/*'

jobs:
cuda-test:
# Don't run on forked repos
if: github.repository_owner == 'pytorch'
runs-on: [a100-runner]
timeout-minutes: 240
environment: docker-s3-upload
env:
BASE_CONDA_ENV: "torchbench"
CONDA_ENV: "tritonbench-pr-test-cuda"
SETUP_SCRIPT: "/workspace/setup_instance.sh"
TEST_CONFIG: "cuda"
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
steps:
- name: Checkout TorchBench
uses: actions/checkout@v3
with:
submodules: 'true'
- name: Tune Nvidia GPU
run: |
sudo nvidia-smi -pm 1
sudo nvidia-smi -ac 1215,1410
sudo ldconfig
nvidia-smi
- name: Install triton-nightly
run: |
bash ./.ci/tritonbench/install-triton-nightly.sh
- name: Test Tritonbench install
run: |
bash ./.ci/tritonbench/test-install.sh
- name: Test Tritonbench operators
run: |
bash ./.ci/tritonbench/test-operators.sh
- name: Clean up Conda env
if: always()
run: |
. "${SETUP_SCRIPT}"
conda deactivate && conda deactivate
conda remove -n "${CONDA_ENV}" --all
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
cancel-in-progress: true
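
Outside CI, the job's script steps reduce to a short sequence once the env block above is exported. A sketch, assuming an A100-class machine with the same setup script (requires sudo for the clock tuning):

# Approximate local replay of the cuda-test job
export BASE_CONDA_ENV="torchbench"
export CONDA_ENV="tritonbench-pr-test-cuda"
export SETUP_SCRIPT="/workspace/setup_instance.sh"
sudo nvidia-smi -pm 1          # persistence mode, as in the Tune Nvidia GPU step
sudo nvidia-smi -ac 1215,1410  # lock application clocks, as in CI
bash ./.ci/tritonbench/install-triton-nightly.sh
bash ./.ci/tritonbench/test-install.sh
bash ./.ci/tritonbench/test-operators.sh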
60 changes: 60 additions & 0 deletions .github/workflows/userbenchmark-a100-release.yml
@@ -0,0 +1,60 @@
name: Release TorchBench Userbenchmark on A100
on:
pull_request:
paths:
- userbenchmark/release-test/*

jobs:
run-userbenchmark:
runs-on: [a100-runner]
timeout-minutes: 1440 # 24 hours
environment: docker-s3-upload
env:
BASE_CONDA_ENV: "torchbench"
CONDA_ENV: "userbenchmark-a100"
PLATFORM_NAME: "gcp_a100"
SETUP_SCRIPT: "/workspace/setup_instance.sh"
steps:
- name: Checkout TorchBench
uses: actions/checkout@v3
with:
path: benchmark
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
- name: Tune Nvidia GPU
run: |
sudo nvidia-smi -pm 1
sudo nvidia-smi -ac 1215,1410
nvidia-smi
- name: Clone and setup conda env
run: |
CONDA_ENV=${BASE_CONDA_ENV} . "${SETUP_SCRIPT}"
conda create --name "${CONDA_ENV}" --clone "${BASE_CONDA_ENV}"
- name: Install TorchBench
run: |
set -x
. "${SETUP_SCRIPT}"
pushd benchmark
python install.py
- name: Run user benchmark
run: |
set -x
. "${SETUP_SCRIPT}"
# remove old results
if [ -d benchmark-output ]; then rm -Rf benchmark-output; fi
pushd benchmark
release_version=$(cat userbenchmark/release-test/version.txt)
if [ -d .userbenchmark ]; then rm -Rf .userbenchmark; fi
python run_benchmark.py release-test -c ${release_version}
cp -r ./.userbenchmark/release-test ../benchmark-output
- name: Upload artifact
uses: actions/upload-artifact@v3
with:
name: TorchBench result
path: benchmark-output/
- name: Clean up Conda env
if: always()
run: |
. "${SETUP_SCRIPT}"
conda deactivate && conda deactivate
conda remove -n "${CONDA_ENV}" --all
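
The benchmark step itself is self-contained and can presumably be replayed outside CI. A sketch, assuming the repository is checked out into ./benchmark as the workflow's checkout step arranges:

# Sketch of the "Run user benchmark" step outside CI
. "${SETUP_SCRIPT}"
cd benchmark
release_version=$(cat userbenchmark/release-test/version.txt)
python run_benchmark.py release-test -c "${release_version}"
ls .userbenchmark/release-test  # results are collected here before upload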
12 changes: 12 additions & 0 deletions pyproject.toml
@@ -0,0 +1,12 @@
[build-system]
# Use legacy backend to import local packages in setup.py
build-backend = "setuptools.build_meta:__legacy__"


[tool.black]
line-length = 88
target-version = ["py38"]
exclude = '''/submodules/.*'''

[tool.usort]
excludes = ["**/submodules/**"]
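
The [tool.black] and [tool.usort] tables configure formatters that read pyproject.toml automatically, so formatting the tree is presumably just a matter of running both tools from the repo root; a hedged sketch:

# Format the tree with the settings above; both configs exclude submodules
pip install black usort
black .
usort format .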
5 changes: 3 additions & 2 deletions requirements.txt
@@ -10,8 +10,9 @@ pytest-benchmark
requests
tabulate
git+https://github.com/huggingface/pytorch-image-models.git@730b907
-# this version of transformers is required as per this page https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1
-transformers==4.38.1
+# this version of transformers is required by liger-kernel
+# https://github.com/linkedin/Liger-Kernel/blob/main/pyproject.toml#L23
+transformers==4.44.2
MonkeyType
psutil
pyyaml
1 change: 1 addition & 0 deletions torchbenchmark/operators/FusedLinearCrossEntropy/__init__.py
@@ -0,0 +1 @@
from .operator import Operator
108 changes: 108 additions & 0 deletions torchbenchmark/operators/FusedLinearCrossEntropy/operator.py
@@ -0,0 +1,108 @@
import argparse
from typing import Callable, Generator, List, Optional

import torch

from torchbenchmark.util.triton_op import BenchmarkOperator, register_benchmark

try:
from liger_kernel.transformers.fused_linear_cross_entropy import (
LigerFusedLinearCrossEntropyLoss,
)
except ModuleNotFoundError:
LigerFusedLinearCrossEntropyLoss = None

# Reference: https://github.com/linkedin/Liger-Kernel/blob/\
# 3d0653b035222cbb845435a1994854e4fd219107/benchmark/scripts/benchmark_fused_linear_cross_entropy.py


def parse_op_args(args: List[str]):
parser = argparse.ArgumentParser()
parser.add_argument("--hidden-size", type=int, default=4096, help="hidden size")
parser.add_argument("--vocab-size", type=int, default=128256, help="vocab size")
return parser.parse_args(args)


class TorchLMHeadCE(torch.nn.Module):
"""Ground truth implementation of the linear fused with torch based cross entropy loss.
:param H: hidden size
:param V: vocab size
:param ignore_index: index to ignore
:param reduction: reduction method
"""

def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
super().__init__()
self.lin = torch.nn.Linear(
in_features=H, out_features=V, bias=False, dtype=dtype
)
self.ce_loss = torch.nn.CrossEntropyLoss(
ignore_index=ignore_index, reduction="mean"
)

def forward(self, input, target):
logits = self.lin(input)
return self.ce_loss(logits, target)


class LigerLMHeadCE(torch.nn.Module):
def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
super().__init__()
self.lin = torch.nn.Linear(
in_features=H, out_features=V, bias=False, dtype=dtype
)
self.ce_loss = LigerFusedLinearCrossEntropyLoss(
ignore_index=ignore_index, reduction="mean"
)

def forward(self, input, target):
return self.ce_loss(self.lin.weight, input, target)


class Operator(BenchmarkOperator):
def __init__(
self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
):
super().__init__(tb_args, extra_args)
op_args = parse_op_args(self.extra_args)
self.hidden_size = op_args.hidden_size
self.vocab_size = op_args.vocab_size
self.baseline_model = TorchLMHeadCE(
H=self.hidden_size, V=self.vocab_size, dtype=self.dtype
).to(self.device)
self.liger_model = LigerLMHeadCE(
H=self.hidden_size, V=self.vocab_size, dtype=self.dtype
).to(self.device)
self.use_cuda_graphs = False

def get_input_iter(self) -> Generator:
for BT in [2**i for i in range(12, 16)]:
_input = torch.randn(
BT,
self.hidden_size,
requires_grad=True,
dtype=self.dtype,
device=self.device,
)
target = torch.randint(
self.vocab_size, (BT, 1), dtype=torch.long, device=self.device
).squeeze(1)
yield _input, target

@register_benchmark(baseline=True)
def LMHeadCE(self, input, target) -> Callable:
return lambda: self.baseline_model(input, target)

@register_benchmark()
def LigerLMHeadCE(self, input, target) -> Callable:
return lambda: self.liger_model(input, target)

@register_benchmark()
def inductor_fused_linear_cross_entropy(self, input, target) -> Callable:
compiled = torch.compile(self.baseline_model, dynamic=False)
return lambda: compiled(input, target)

def get_bwd_fn(self, fwd_fn: Callable) -> Callable:
y = fwd_fn()
return lambda: y.backward(retain_graph=True)
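
Following the pattern in .ci/tritonbench/test-operators.sh, the new operator would presumably be smoke-tested the same way. The op name below is inferred from the directory name and the flags from parse_op_args; the exact invocation is an assumption, not taken from this commit:

# Hypothetical smoke test for the new operator, mirroring test-operators.sh
python run_benchmark.py triton --op FusedLinearCrossEntropy --mode fwd --num-inputs 1 --test-only
# Operator-specific knobs defined in parse_op_args (defaults shown):
python run_benchmark.py triton --op FusedLinearCrossEntropy --mode fwd --num-inputs 1 --test-only --hidden-size 4096 --vocab-size 128256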