Skip to content

Commit

Permalink
Add SM90 CUTLASS 3.x kernels to perm102_bmm ops (facebookincubator#689)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebookincubator#689

ATT. SM90 kernels are added to the following ops:

- `perm102_bmm_rcr`
- `perm102_bmm_rcr_bias`
- `perm102_bmm_rrr`
- `perm102_bmm_rrr_bias`

Reviewed By: chenyang78

Differential Revision: D45817288

fbshipit-source-id: 9096b0d920416bd1c5b4de377a831642f2fc6a18
  • Loading branch information
aakhundov authored and facebook-github-bot committed May 19, 2023
1 parent 02b04e4 commit d7e7996
Show file tree
Hide file tree
Showing 8 changed files with 388 additions and 151 deletions.
48 changes: 30 additions & 18 deletions python/aitemplate/backend/cuda/gemm_universal/bmm_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""
Common functions and templates for bmm-family ops
"""
from dataclasses import dataclass
import dataclasses

import jinja2

Expand Down Expand Up @@ -157,7 +157,7 @@
)


@dataclass
@dataclasses.dataclass
class Bmm_problem_info:
alpha_value: float = 1
beta_value: float = 0
Expand Down Expand Up @@ -576,6 +576,30 @@ def gen_profiler(
return common.build_profiler(file_pairs)


def add_elem_types_to_mm_info(mm_info, func_attrs):
    """
    Return a copy of mm_info whose pointer fields carry explicit casts.

    CUTLASS 3.x problem args require explicit I/O pointer types
    (not void*). This function augments the a_ptr and b_ptr
    expressions with a cast to the element type derived from the
    first input's dtype, and the bias_ptr and c_ptr expressions
    with a cast to the element type derived from the first output's
    dtype. The mm_info passed in is not modified; the updated copy
    is produced with dataclasses.replace.
    """
    spec = CUDASpec()
    in_type = spec.dtype_to_lib_type(func_attrs["inputs"][0]._attrs["dtype"])
    out_type = spec.dtype_to_lib_type(func_attrs["outputs"][0]._attrs["dtype"])

    # Wrap each original pointer expression in a C-style cast.
    typed_ptrs = {
        "a_ptr": f"({in_type}*)({mm_info.a_ptr})",
        "b_ptr": f"({in_type}*)({mm_info.b_ptr})",
        "bias_ptr": f"({out_type}*)({mm_info.bias_ptr})",
        "c_ptr": f"({out_type}*)({mm_info.c_ptr})",
    }
    return dataclasses.replace(mm_info, **typed_ptrs)


def default_gen_profiler(
func_attrs,
workdir,
Expand Down Expand Up @@ -603,23 +627,11 @@ def default_gen_profiler(
problem_args = PROBLEM_ARGS_TEMPLATE.render(
mm_info=default_mm_info,
)

backend_spec = CUDASpec()
elem_input_type = backend_spec.dtype_to_lib_type(
func_attrs["inputs"][0]._attrs["dtype"]
)
elem_output_type = backend_spec.dtype_to_lib_type(
func_attrs["outputs"][0]._attrs["dtype"]
)

# CUTLASS 3.x problem args require explicit I/O pointer types (not void*)
default_mm_info.a_ptr = f"({elem_input_type}*)({default_mm_info.a_ptr})"
default_mm_info.b_ptr = f"({elem_input_type}*)({default_mm_info.b_ptr})"
default_mm_info.bias_ptr = f"({elem_output_type}*)({default_mm_info.bias_ptr})"
default_mm_info.c_ptr = f"({elem_output_type}*)({default_mm_info.c_ptr})"

problem_args_cutlass_3x = PROBLEM_ARGS_TEMPLATE_CUTLASS_3X.render(
mm_info=default_mm_info,
mm_info=add_elem_types_to_mm_info(
mm_info=default_mm_info,
func_attrs=func_attrs,
),
)

return gen_profiler(
Expand Down
21 changes: 4 additions & 17 deletions python/aitemplate/backend/cuda/gemm_universal/bmm_xxx_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@


from aitemplate.backend import registry
from aitemplate.backend.backend_spec import CUDASpec
from aitemplate.backend.common import gemm_common
from aitemplate.backend.cuda.gemm_universal import bmm_common, common
from aitemplate.backend.cuda.gemm_universal.bmm_xxx import _get_problem_args, get_config
Expand Down Expand Up @@ -108,23 +107,11 @@ def gen_profiler(func_attrs, workdir, profiler_filename, dim_info_dict):
problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render(
mm_info=default_mm_info,
)

backend_spec = CUDASpec()
elem_input_type = backend_spec.dtype_to_lib_type(
func_attrs["inputs"][0]._attrs["dtype"]
)
elem_output_type = backend_spec.dtype_to_lib_type(
func_attrs["outputs"][0]._attrs["dtype"]
)

# CUTLASS 3.x problem args require explicit I/O pointer types (not void*)
default_mm_info.a_ptr = f"({elem_input_type}*)({default_mm_info.a_ptr})"
default_mm_info.b_ptr = f"({elem_input_type}*)({default_mm_info.b_ptr})"
default_mm_info.bias_ptr = f"({elem_output_type}*)({default_mm_info.bias_ptr})"
default_mm_info.c_ptr = f"({elem_output_type}*)({default_mm_info.c_ptr})"

problem_args_cutlass_3x = bmm_common.PROBLEM_ARGS_TEMPLATE_CUTLASS_3X.render(
mm_info=default_mm_info,
mm_info=bmm_common.add_elem_types_to_mm_info(
mm_info=default_mm_info,
func_attrs=func_attrs,
),
)

return bmm_common.gen_profiler(
Expand Down
51 changes: 37 additions & 14 deletions python/aitemplate/backend/cuda/gemm_universal/perm102_bmm_rcr.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ def _get_default_problem_info(**kwargs):
"ldb": "K",
"ldbias": "N * B",
"ldc": "N * B",
"a_row_major": True,
"b_row_major": False,
"c_row_major": True,
}
for k, v in kwargs.items():
problem_args[k] = v
Expand Down Expand Up @@ -63,6 +66,9 @@ def _get_strided_problem_info(func_attrs):
ldb="K",
ldbias="output_stride",
ldc="output_stride",
a_row_major=True,
b_row_major=False,
c_row_major=True,
)


Expand Down Expand Up @@ -112,7 +118,10 @@ def fproc(op):
epilogue_name=func_attrs["epilogue"],
)

func_attrs["op_instance"] = common.extract_config(fproc)
func_attrs["op_instance"] = common.extract_config(
f_proc_op=fproc,
include_cutlass_3x_ops=True,
)


@registry.reg("cuda.perm102_bmm_rcr.gen_profiler")
Expand All @@ -127,15 +136,22 @@ def gen_profiler(func_attrs, workdir, profiler_filename, dim_info_dict):
problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render(
mm_info=mm_info,
)
problem_args_cutlass_3x = bmm_common.PROBLEM_ARGS_TEMPLATE_CUTLASS_3X.render(
mm_info=bmm_common.add_elem_types_to_mm_info(
mm_info=mm_info,
func_attrs=func_attrs,
),
)

return bmm_common.gen_profiler(
func_attrs,
workdir,
profiler_filename,
dim_info_dict,
common.SRC_TEMPLATE,
problem_args,
args_parser,
func_attrs=func_attrs,
workdir=workdir,
profiler_filename=profiler_filename,
dim_info_dict=dim_info_dict,
src_template=common.SRC_TEMPLATE,
problem_args=problem_args,
problem_args_cutlass_3x=problem_args_cutlass_3x,
args_parser=args_parser,
)


Expand All @@ -151,14 +167,21 @@ def gen_function(
problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render(
mm_info=bmm_problem_info,
)
problem_args_cutlass_3x = bmm_common.PROBLEM_ARGS_TEMPLATE_CUTLASS_3X.render(
mm_info=bmm_common.add_elem_types_to_mm_info(
mm_info=bmm_problem_info,
func_attrs=func_attrs,
),
)

return bmm_common.gen_function(
func_attrs,
exec_cond_template,
problem_args,
dim_info_dict,
"", # input_addr_calculator
get_output_addr_calculator(func_attrs),
func_attrs=func_attrs,
exec_cond_template=exec_cond_template,
problem_args=problem_args,
problem_args_cutlass_3x=problem_args_cutlass_3x,
dim_info_dict=dim_info_dict,
input_addr_calculator="",
output_addr_calculator=get_output_addr_calculator(func_attrs),
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ def _get_default_problem_info(**kwargs):
"ldb": "K",
"ldbias": "0",
"ldc": "N * B",
"a_row_major": True,
"b_row_major": False,
"c_row_major": True,
}
for k, v in kwargs.items():
problem_args[k] = v
Expand Down Expand Up @@ -73,6 +76,9 @@ def _get_strided_problem_info(func_attrs):
ldb="K",
ldbias="0",
ldc="output_stride",
a_row_major=True,
b_row_major=False,
c_row_major=True,
)


Expand All @@ -93,15 +99,22 @@ def gen_profiler(func_attrs, workdir, profiler_filename, dim_info_dict):
problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render(
mm_info=mm_info,
)
problem_args_cutlass_3x = bmm_common.PROBLEM_ARGS_TEMPLATE_CUTLASS_3X.render(
mm_info=bmm_common.add_elem_types_to_mm_info(
mm_info=mm_info,
func_attrs=func_attrs,
),
)

return bmm_common.gen_profiler(
func_attrs,
workdir,
profiler_filename,
dim_info_dict,
common_bias.SRC_TEMPLATE,
problem_args,
args_parser,
func_attrs=func_attrs,
workdir=workdir,
profiler_filename=profiler_filename,
dim_info_dict=dim_info_dict,
src_template=common_bias.SRC_TEMPLATE,
problem_args=problem_args,
problem_args_cutlass_3x=problem_args_cutlass_3x,
args_parser=args_parser,
bias_ptr_arg="memory_pool->RequestTensorByIdx(3)",
)

Expand All @@ -118,16 +131,23 @@ def gen_function(
problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render(
mm_info=bmm_problem_info,
)
problem_args_cutlass_3x = bmm_common.PROBLEM_ARGS_TEMPLATE_CUTLASS_3X.render(
mm_info=bmm_common.add_elem_types_to_mm_info(
mm_info=bmm_problem_info,
func_attrs=func_attrs,
),
)

input_ndims = len(func_attrs["input_accessors"][0].original_shapes)
weight_ndims = len(func_attrs["input_accessors"][1].original_shapes)
output_ndims = len(func_attrs["output_accessors"][0].original_shapes)

return common.gen_function(
func_attrs,
common_bias.SRC_TEMPLATE,
exec_cond_template,
problem_args,
func_attrs=func_attrs,
src_template=common_bias.SRC_TEMPLATE,
exec_cond_template=exec_cond_template,
problem_args=problem_args,
problem_args_cutlass_3x=problem_args_cutlass_3x,
input_ndims=input_ndims,
weight_ndims=weight_ndims,
output_ndims=output_ndims,
Expand Down
51 changes: 37 additions & 14 deletions python/aitemplate/backend/cuda/gemm_universal/perm102_bmm_rrr.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ def _get_default_problem_info(**kwargs):
"ldb": "N",
"ldbias": "N * B",
"ldc": "N * B",
"a_row_major": True,
"b_row_major": True,
"c_row_major": True,
}
for k, v in kwargs.items():
problem_args[k] = v
Expand Down Expand Up @@ -66,6 +69,9 @@ def _get_strided_problem_info(func_attrs):
ldb="N",
ldbias="output_stride",
ldc="output_stride",
a_row_major=True,
b_row_major=True,
c_row_major=True,
)


Expand All @@ -83,7 +89,10 @@ def fproc(op):
epilogue_name=func_attrs["epilogue"],
)

func_attrs["op_instance"] = common.extract_config(fproc)
func_attrs["op_instance"] = common.extract_config(
f_proc_op=fproc,
include_cutlass_3x_ops=True,
)


@registry.reg("cuda.perm102_bmm_rrr.gen_profiler")
Expand All @@ -98,15 +107,22 @@ def gen_profiler(func_attrs, workdir, profiler_filename, dim_info_dict):
problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render(
mm_info=mm_info,
)
problem_args_cutlass_3x = bmm_common.PROBLEM_ARGS_TEMPLATE_CUTLASS_3X.render(
mm_info=bmm_common.add_elem_types_to_mm_info(
mm_info=mm_info,
func_attrs=func_attrs,
),
)

return bmm_common.gen_profiler(
func_attrs,
workdir,
profiler_filename,
dim_info_dict,
common.SRC_TEMPLATE,
problem_args,
args_parser,
func_attrs=func_attrs,
workdir=workdir,
profiler_filename=profiler_filename,
dim_info_dict=dim_info_dict,
src_template=common.SRC_TEMPLATE,
problem_args=problem_args,
problem_args_cutlass_3x=problem_args_cutlass_3x,
args_parser=args_parser,
)


Expand All @@ -122,14 +138,21 @@ def gen_function(
problem_args = bmm_common.PROBLEM_ARGS_TEMPLATE.render(
mm_info=bmm_problem_info,
)
problem_args_cutlass_3x = bmm_common.PROBLEM_ARGS_TEMPLATE_CUTLASS_3X.render(
mm_info=bmm_common.add_elem_types_to_mm_info(
mm_info=bmm_problem_info,
func_attrs=func_attrs,
),
)

return bmm_common.gen_function(
func_attrs,
exec_cond_template,
problem_args,
dim_info_dict,
"", # input_addr_calculator
get_output_addr_calculator(func_attrs),
func_attrs=func_attrs,
exec_cond_template=exec_cond_template,
problem_args=problem_args,
problem_args_cutlass_3x=problem_args_cutlass_3x,
dim_info_dict=dim_info_dict,
input_addr_calculator="",
output_addr_calculator=get_output_addr_calculator(func_attrs),
)


Expand Down
Loading

0 comments on commit d7e7996

Please sign in to comment.