Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rocm support for moe tuning script #251

Merged
merged 8 commits into from
Nov 13, 2024
272 changes: 224 additions & 48 deletions benchmarks/kernels/benchmark_moe.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import argparse
import time
from datetime import datetime
from itertools import product
from typing import Any, Dict, List, Tuple, TypedDict

import ray
Expand All @@ -11,7 +12,10 @@

from vllm.model_executor.layers.fused_moe.fused_moe import *
from vllm.platforms import current_platform
from vllm.utils import FlexibleArgumentParser
from vllm.utils import FlexibleArgumentParser, is_navi

FP8_DTYPE = torch.float8_e4m3fnuz if current_platform.is_rocm(
) and not is_navi() else torch.float8_e4m3fn


class BenchmarkConfig(TypedDict):
Expand Down Expand Up @@ -80,8 +84,8 @@ def benchmark_config(
a1_scale = torch.randn(1, dtype=torch.float32)
a2_scale = torch.randn(1, dtype=torch.float32)

w1 = w1.to(torch.float8_e4m3fn)
w2 = w2.to(torch.float8_e4m3fn)
w1 = w1.to(FP8_DTYPE)
w2 = w2.to(FP8_DTYPE)

input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32)

Expand Down Expand Up @@ -141,35 +145,183 @@ def run():
return avg


def get_configs_compute_bound() -> List[Dict[str, int]]:
# Reduced search space for faster tuning.
# TODO(woosuk): Increase the search space and use a performance model to
# prune the search space.
def get_rocm_tuning_space(use_fp16):
block_mn_range = [16, 32, 64, 128, 256]
block_k_range = [16, 32, 64, 128, 256]
if not use_fp16:
block_k_range.remove(16) # BLOCK_K=16 not supported for fp8
num_warps_range = [1, 2, 4, 8]
group_m_range = [1, 4, 8, 16, 32]
# For now we see better perf with num_stages=0 for all gemm configs we care
# But keep this explicit so that we do not forget we may need to set it to
# other values in the future
num_stage_range = [0]
waves_per_eu_range = [0]
matrix_instr_nonkdim_range = [16, 32] if use_fp16 else []
kpack_range = [1, 2] if use_fp16 else []

param_ranges = {
"BLOCK_SIZE_M": block_mn_range,
"BLOCK_SIZE_N": block_mn_range,
"BLOCK_SIZE_K": block_k_range,
"GROUP_SIZE_M": group_m_range,
"num_warps": num_warps_range,
"num_stages": num_stage_range,
"waves_per_eu": waves_per_eu_range,
}
if use_fp16:
param_ranges["matrix_instr_nonkdim"] = matrix_instr_nonkdim_range
param_ranges["kpack"] = kpack_range

return param_ranges


def get_configs_compute_bound(use_fp16) -> List[Dict[str, int]]:
configs: List[BenchmarkConfig] = []
for num_stages in [2, 3, 4, 5]:
for block_m in [16, 32, 64, 128, 256]:
for block_k in [64, 128, 256]:
for block_n in [32, 64, 128, 256]:
for num_warps in [4, 8]:
for group_size in [1, 16, 32, 64]:
configs.append({
"BLOCK_SIZE_M": block_m,
"BLOCK_SIZE_N": block_n,
"BLOCK_SIZE_K": block_k,
"GROUP_SIZE_M": group_size,
"num_warps": num_warps,
"num_stages": num_stages,
})

if current_platform.is_rocm():
param_ranges = get_rocm_tuning_space(use_fp16)
else:
# Reduced search space for faster tuning.
# TODO(woosuk): Increase the search space and use a performance model to
# prune the search space.
block_m_range = [16, 32, 64, 128, 256]
block_n_range = [32, 64, 128, 256]
block_k_range = [64, 128, 256]
num_warps_range = [4, 8]
group_m_range = [1, 16, 32, 64]
num_stage_range = [2, 3, 4, 5]

param_ranges = {
"BLOCK_SIZE_M": block_m_range,
"BLOCK_SIZE_N": block_n_range,
"BLOCK_SIZE_K": block_k_range,
"GROUP_SIZE_M": group_m_range,
"num_warps": num_warps_range,
"num_stages": num_stage_range,
}

keys, values = zip(*param_ranges.items())
for config_values in product(*values):
config = dict(zip(keys, config_values))
configs.append(config)
return configs


def prune_search_space(num_tokens, shard_intermediate_size, hidden_size,
search_space, is_fp16):
N1, K1 = shard_intermediate_size, hidden_size
N2, K2 = hidden_size, shard_intermediate_size // 2
pruned_space_1 = prune_configs(num_tokens * 2, N1, K1, search_space,
is_fp16)
pruned_space_2 = prune_configs(num_tokens * 2, N2, K2, search_space,
is_fp16)
search_space = merge_unique_dicts(pruned_space_1, pruned_space_2)
return search_space


def prune_configs(M, N, K, configs, is_fp16=True):
pruned_configs = []
elemBytes_a = 2 if is_fp16 else 1
elemBytes_b = 2 if is_fp16 else 1

mfma = 16 if M < 32 or N < 32 else 32

# TODO (zhanglx): figure out the boundary between large and small gemms
large_gemm = False
if M >= 2048 and N >= 2048:
large_gemm = True

for config in configs:
BLOCK_SIZE_M = config.get("BLOCK_SIZE_M")
BLOCK_SIZE_N = config.get("BLOCK_SIZE_N")
BLOCK_SIZE_K = config.get("BLOCK_SIZE_K")
num_warps = config.get("num_warps")

if is_fp16:
matrix_instr_nonkdim = config.get("matrix_instr_nonkdim")
if matrix_instr_nonkdim > mfma:
continue
if mfma == 4 and BLOCK_SIZE_K < 64:
continue
# some layouts could not work properly in case
# number elements per thread is less 1
if BLOCK_SIZE_M * BLOCK_SIZE_N < 64:
continue
SPLIT_K = config.get("SPLIT_K", 1)
GROUP_M = config.get("GROUP_SIZE_M")
if is_fp16:
if (matrix_instr_nonkdim > BLOCK_SIZE_M
or matrix_instr_nonkdim > BLOCK_SIZE_N):
continue
if (matrix_instr_nonkdim >= M
and matrix_instr_nonkdim != BLOCK_SIZE_M):
continue
if (matrix_instr_nonkdim >= N
and matrix_instr_nonkdim != BLOCK_SIZE_N):
continue
# Skip BLOCK_SIZE that is too large compare to M/N
# unless BLOCK_SIZE is already small enough
if M * 2 < BLOCK_SIZE_M and BLOCK_SIZE_M != 16:
continue
if N * 2 < BLOCK_SIZE_N and BLOCK_SIZE_N != 16:
continue
# skip large split_k when not necessary
if SPLIT_K != 1 and not need_split_k(M, N, K):
continue
# skip split_k that leads to EVEN_K = false
leap = SPLIT_K * BLOCK_SIZE_K
modv = K % leap
if modv != 0:
continue
# skip large GROUP_M
if GROUP_M * BLOCK_SIZE_M > M and GROUP_M != 1:
continue
# out of shared memory resource
# TODO (zhanglx): This does not consider the LDS usage in the epilogue
LDS = (BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a +
BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b)
if LDS > 65536:
continue
# Skip small block sizes and num_warps for large gemm
# For fp16 and f8, we want to only use BLOCK_SIZE >= 64
if large_gemm:
if BLOCK_SIZE_M < 64 or BLOCK_SIZE_N < 64:
continue
if BLOCK_SIZE_K < 64:
continue
if num_warps < 4:
continue

pruned_configs.append(config)

return pruned_configs


def need_split_k(SIZE_M, SIZE_N, SIZE_K):
return (SIZE_M < 64 or SIZE_N < 64) and SIZE_K > 1024


def merge_unique_dicts(list1, list2):
result = []
combined_list = list1.copy()
combined_list.extend(list2)
for dictionary in combined_list:
if dictionary not in result:
result.append(dictionary)
return result


@ray.remote(num_gpus=1)
class BenchmarkWorker:

def __init__(self, seed: int) -> None:
torch.set_default_device("cuda")
current_platform.seed_everything(seed)
self.seed = seed
# get the device id to allocate the tensors and kernels explicitly
# on the respective GPU ID
self.device_id = int(ray.get_gpu_ids()[0])

def benchmark(
self,
Expand Down Expand Up @@ -217,25 +369,33 @@ def tune(
) -> Dict[str, int]:
best_config = None
best_time = float("inf")
for config in tqdm(search_space):
try:
kernel_time = benchmark_config(config,
num_tokens,
num_experts,
shard_intermediate_size,
hidden_size,
topk,
dtype,
use_fp8_w8a8,
use_int8_w8a16,
num_iters=10)
except triton.runtime.autotuner.OutOfResources:
# Some configurations may be invalid and fail to compile.
continue

if kernel_time < best_time:
best_time = kernel_time
best_config = config
if current_platform.is_rocm():
is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16)
search_space = prune_search_space(num_tokens,
shard_intermediate_size,
hidden_size, search_space,
is_fp16)

with torch.cuda.device(self.device_id):
for config in tqdm(search_space):
try:
kernel_time = benchmark_config(config,
num_tokens,
num_experts,
shard_intermediate_size,
hidden_size,
topk,
dtype,
use_fp8_w8a8,
use_int8_w8a16,
num_iters=20)
except triton.runtime.autotuner.OutOfResources:
# Some configurations may be invalid and fail to compile.
continue

if kernel_time < best_time:
best_time = kernel_time
best_config = config
now = datetime.now()
print(f"{now.ctime()}] Completed tuning for batch_size={num_tokens}")
assert best_config is not None
Expand All @@ -244,12 +404,27 @@ def tune(

def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
return {
"BLOCK_SIZE_M": config["BLOCK_SIZE_M"],
"BLOCK_SIZE_N": config["BLOCK_SIZE_N"],
"BLOCK_SIZE_K": config["BLOCK_SIZE_K"],
"GROUP_SIZE_M": config["GROUP_SIZE_M"],
"num_warps": config["num_warps"],
"num_stages": config["num_stages"],
"BLOCK_SIZE_M":
config["BLOCK_SIZE_M"],
"BLOCK_SIZE_N":
config["BLOCK_SIZE_N"],
"BLOCK_SIZE_K":
config["BLOCK_SIZE_K"],
"GROUP_SIZE_M":
config["GROUP_SIZE_M"],
"num_warps":
config["num_warps"],
"num_stages":
config["num_stages"],
**({
"waves_per_eu": config["waves_per_eu"]
} if "waves_per_eu" in config else {}),
**({
"matrix_instr_nonkdim": config["matrix_instr_nonkdim"]
} if "matrix_instr_nonkdim" in config else {}),
**({
"kpack": config["kpack"]
} if "kpack" in config else {}),
}


Expand Down Expand Up @@ -294,7 +469,7 @@ def main(args: argparse.Namespace):
shard_intermediate_size = 2 * intermediate_size // args.tp_size

hidden_size = config.hidden_size
dtype = config.torch_dtype
dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype
use_fp8_w8a8 = args.dtype == "fp8_w8a8"
use_int8_w8a16 = args.dtype == "int8_w8a16"

Expand Down Expand Up @@ -322,7 +497,8 @@ def _distribute(method: str, inputs: List[Any]) -> List[Any]:
return ray.get(outputs)

if args.tune:
search_space = get_configs_compute_bound()
is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16)
search_space = get_configs_compute_bound(is_fp16)
print(f"Start tuning over {len(search_space)} configurations...")

start = time.time()
Expand Down