Enable dynamic M grouped gemm (#3444)
Summary:
Pull Request resolved: #3444

X-link: facebookresearch/FBGEMM#530

This diff adds support for true dynamic M, as found in grouped_gemm. To do so, we add a new `zero_start_index_M` argument that must be provided by the user and indicates the number of valid (non-zero) rows M in each input tensor. One nice thing about this approach is that the gemm arguments can now be set up with a single kernel call.

We make `zero_start_index_M` optional, since it requires fixed N and K. When N and K vary across groups, we fall back to the previous static-shape approach.
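
As a rough illustration of the new calling convention (not part of this commit), the sketch below drives the fixed-N/K path directly through the `f8f8bf16_rowwise_grouped` op. The group count, shapes, device, scale values, and FP8 dtype are made-up placeholders; only the operator name and the `zero_start_index_M` / `output` keyword arguments come from the diff below.

```python
# Sketch only: shapes, dtypes, and device are illustrative assumptions.
import torch
import fbgemm_gpu.experimental.gen_ai  # noqa: F401  (registers the fbgemm ops)

G, max_M, N, K = 4, 128, 256, 512      # one fixed N/K shared by every group
m_values = [128, 64, 96, 1]            # valid rows per group; remaining rows are zero padding

# Pre-quantized FP8 inputs and rowwise scales, one entry per group.
xq = [torch.randn(max_M, K, device="cuda").to(torch.float8_e4m3fn) for _ in range(G)]
wq = [torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn) for _ in range(G)]
x_scale = [torch.ones(max_M, device="cuda") for _ in range(G)]
w_scale = [torch.ones(N, device="cuda") for _ in range(G)]
output = [torch.empty(max_M, N, device="cuda", dtype=torch.bfloat16) for _ in range(G)]

# zero_start_index_M tells the kernel how many rows of each group are valid,
# so all gemm arguments can be set up in a single kernel call.
zero_start_index_M = torch.tensor(m_values, dtype=torch.int32, device="cuda")

out = torch.ops.fbgemm.f8f8bf16_rowwise_grouped(
    xq,
    wq,
    x_scale,
    w_scale,
    zero_start_index_M=zero_start_index_M,
    output=output,
)
```

When N or K differ across groups, the op is still called without `zero_start_index_M`, as in the existing static-shape path.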

Reviewed By: bradleyhd, jiawenliu64

Differential Revision: D66682886

fbshipit-source-id: 9c4554dba9becf33fcc87cd1b01266fead716916
jwfromm authored and facebook-github-bot committed Dec 19, 2024
1 parent 0b1739c commit 804a499
Showing 4 changed files with 247 additions and 53 deletions.
21 changes: 17 additions & 4 deletions fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py
@@ -62,6 +62,7 @@ def benchmark_grouped(
kernels: Optional[List[str]] = None,
bench_quantize: bool = False,
use_rotating_buffer_bench: bool = False,
+    use_cuda_graph: bool = True,
) -> Dict[str, Any]:
num_groups = len(m)
# Create input tensors.
@@ -92,6 +93,8 @@ def benchmark_grouped(
quantized_vals = quantize_op.quantize(A, B)
# Compute the output given quantized values.
output = quantize_op.compute(*quantized_vals)
+        # Some kernels may pad the output; keep only the first m rows of each group's output.
+        output = [o[: m[i]] for i, o in enumerate(output)]
# Compare the quantize op output to reference as a sanity check.
sim_check: float = 0
for i in range(num_groups):
@@ -107,14 +110,14 @@
B,
bench_quantize=True,
use_rotating_buffer_bench=use_rotating_buffer_bench,
-                use_cuda_graph=True,
+                use_cuda_graph=use_cuda_graph,
)
else:
ms_runtime = quantize_op.benchmark(
*quantized_vals,
bench_quantize=False,
use_rotating_buffer_bench=use_rotating_buffer_bench,
-                use_cuda_graph=True,
+                use_cuda_graph=use_cuda_graph,
)

# Print out results for this op.
@@ -124,8 +127,8 @@
tflops += 2 * b[i] * m[i] * n[i] * k[i] / (ms_runtime / 1e3) / 1e12
gbps += (
(
-                    quantized_vals[0][i].numel()
-                    * quantized_vals[0][i].element_size()
+                    quantized_vals[0][i][: m[i]].numel()
+                    * quantized_vals[0][i][: m[i]].element_size()
+ quantized_vals[1][i].numel()
* quantized_vals[1][i].element_size()
+ output[i].numel() * output[i].element_size()
@@ -156,6 +159,7 @@ def benchmark(
kernels: Optional[List[str]] = None,
bench_quantize: bool = False,
use_rotating_buffer_bench: bool = False,
+    use_cuda_graph: bool = True,
) -> Dict[str, Any]:
# Create input tensors.
if b > 1:
@@ -192,12 +196,14 @@
B,
bench_quantize=True,
use_rotating_buffer_bench=use_rotating_buffer_bench,
+                use_cuda_graph=use_cuda_graph,
)
else:
ms_runtime = quantize_op.benchmark(
*quantized_vals,
bench_quantize=False,
use_rotating_buffer_bench=use_rotating_buffer_bench,
+                use_cuda_graph=use_cuda_graph,
)

# Print out results for this op.
@@ -316,6 +322,7 @@ def main(args: Any):
kernels,
args.bench_quantize,
args.use_rotating_buffer_bench,
+            not args.no_cuda_graph,
)
benchmark_results.append(quantize_measurements)
if args.export_csv:
@@ -377,6 +384,12 @@ def invoke_main() -> None:
help="If set, do grouped gemm. In this mode, M, N, and K are interpreted "
"as the size of groups. The length of each must be the same.",
)
+    parser.add_argument(
+        "--no_cuda_graph",
+        default=False,
+        action="store_true",
+        help="If set, do not use cuda graph for benchmarking.",
+    )
parser.add_argument(
"--use_rotating_buffer_bench",
default=False,
77 changes: 69 additions & 8 deletions fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py
@@ -9,6 +9,7 @@
from typing import List, Tuple

import fbgemm_gpu.experimental.gen_ai # noqa: F401
+import numpy as np

import torch
import triton # @manual=//triton:triton
@@ -467,24 +468,84 @@ class FP8RowwiseGroupedGemm(QuantizeOpBase):
FP8 grouped matmul with rowwise scaling.
"""

+    def quantize_fixed_nk(self, x, w):
+        group_size = len(x)
+        m_values = [i.shape[0] for i in x]
+        # Inputs for fixed nk mode must be contiguous; however, in the benchmark
+        # script they typically are not. Do a little special processing to make
+        # them work. In practice this won't be needed.
+        # Start by padding along the m dimension with zeros.
+        max_m = max(m_values)
+        xq = [
+            torch.nn.functional.pad(i, (0, 0, 0, max_m - i.shape[0]), value=0)
+            for i in x
+        ]
+        # Stack inputs into groups.
+        xq = torch.stack(xq).contiguous()
+        wq = torch.stack(w).contiguous()
+        # Allocate output tensor.
+        output = torch.empty(
+            [xq.shape[0], xq.shape[1], wq.shape[1]],
+            dtype=torch.bfloat16,
+            device=xq.device,
+        )
+        # Apply quantization.
+        xq, x_scale = quantize_fp8_row(xq)
+        wq, w_scale = quantize_fp8_row(wq)
+        # View these unified tensors as lists of tensors.
+        xq = [x.squeeze() for x in xq.split(1, dim=0)]
+        wq = [w.squeeze() for w in wq.split(1, dim=0)]
+        output = [o.squeeze() for o in output.split(1, dim=0)]
+        x_scale = [xs.squeeze() for xs in x_scale.view(group_size, -1).split(1, dim=0)]
+        w_scale = [ws.squeeze() for ws in w_scale.view(group_size, -1).split(1, dim=0)]
+
+        # Return processed tensors.
+        return (
+            xq,
+            wq,
+            x_scale,
+            w_scale,
+            torch.tensor(m_values).to(dtype=torch.int32, device=xq[0].device),
+            output,
+        )

def quantize(self, x, w):
-        # Quantize both input tensors.
+        # Handle both grouped and standard gemm.
assert isinstance(
x, (list, tuple)
), "Inputs to group gemm must be a list of tensors."

+        # First check if N and K are fixed.
+        m_values = [i.shape[0] for i in x]
+        n_values = [i.shape[0] for i in w]
+        k_values = [i.shape[1] for i in w]
+        # if so, do specialized version of initialization.
+        if len(np.unique(n_values)) == 1 and len(np.unique(k_values)) == 1:
+            return self.quantize_fixed_nk(x, w)
+
+        # Otherwise handle in eager mode.
xq, x_scale = zip(*[quantize_fp8_row(i) for i in x])
wq, w_scale = zip(*[quantize_fp8_row(i) for i in w])
-        return xq, wq, x_scale, w_scale
-
-    def compute(self, xq, wq, x_scale, w_scale, kernel_name=None):
+        output = [
+            torch.empty(m, n, device=xq[0].device, dtype=torch.bfloat16)
+            for m, n in zip(m_values, n_values)
+        ]
+        m_values = None
+        return xq, wq, x_scale, w_scale, m_values, output
+
+    def compute(self, xq, wq, x_scale, w_scale, m_values, output, kernel_name=None):
return torch.ops.fbgemm.f8f8bf16_rowwise_grouped(
-            xq, wq, x_scale, w_scale, kernel_name=kernel_name
+            xq,
+            wq,
+            x_scale,
+            w_scale,
+            zero_start_index_M=m_values,
+            output=output,
+            kernel_name=kernel_name,
)

def quantize_and_compute(self, x, w):
-        xq, wq, x_scale, w_scale = self.quantize(x, w)
-        return self.compute(xq, wq, x_scale, w_scale)
+        xq, wq, x_scale, w_scale, m_values, output = self.quantize(x, w)
+        return self.compute(xq, wq, x_scale, w_scale, m_values, output)

@property
def name(self) -> str:
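
For completeness, a hedged end-to-end sketch of how the updated benchmark op class can be exercised; the import path is inferred from this file's location, and the no-argument constructor plus the exact output layout are assumptions:

```python
# Sketch only: exercises the new quantize()/compute() flow of FP8RowwiseGroupedGemm.
import torch
from fbgemm_gpu.experimental.gen_ai.bench.quantize_ops import FP8RowwiseGroupedGemm

op = FP8RowwiseGroupedGemm()

# Fixed N/K across groups with varying M should trigger the new dynamic-M path.
x = [torch.randn(m, 512, device="cuda", dtype=torch.bfloat16) for m in (128, 64, 96)]
w = [torch.randn(256, 512, device="cuda", dtype=torch.bfloat16) for _ in range(3)]

out = op.quantize_and_compute(x, w)
# Expected: one [max_M, 256] bf16 tensor per group; rows past each group's true M are padding,
# which is why the benchmark script now slices outputs back to m rows per group.
```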