remove nvfuser scaled_mm
due to its inaccurate results compared to `torch._scaled_mm`

Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
crcrpar committed Dec 21, 2024
1 parent fce281b commit e5d26fc
Showing 2 changed files with 3 additions and 110 deletions.
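For context, the kind of comparison that motivated the removal can be sketched as below. This is a hypothetical repro, not code from the PR: the shapes, scales, and fp8 dtype are illustrative assumptions, and it needs a CUDA device with compute capability 9.0 or newer.

```python
import torch

m, k, n = 16, 32, 64
a = torch.randn(m, k, device="cuda").to(torch.float8_e4m3fn)
# torch._scaled_mm expects the second operand in column-major layout.
b = torch.randn(n, k, device="cuda").to(torch.float8_e4m3fn).t()
scale_a = torch.tensor(1.0, device="cuda")  # tensor-wise float32 scales
scale_b = torch.tensor(1.0, device="cuda")

reference = torch._scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.bfloat16)
# The removed nvFuser path effectively computed this decomposition instead;
# per the commit message, the two did not match closely enough in practice.
decomposed = ((a.float() * scale_a) @ (b.float() * scale_b)).to(torch.bfloat16)
torch.testing.assert_close(reference, decomposed)
```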
110 changes: 0 additions & 110 deletions thunder/executors/nvfuserex_impl.py
@@ -2593,113 +2593,3 @@ def scaled_dot_product_flash_attention_grad(
execution_transform=scaled_dot_product_flash_attention,
grad_transform=scaled_dot_product_flash_attention_grad,
)


def _scaled_mm_check(
    a: TensorProxy,
    b: TensorProxy,
    scale_a: TensorLike,
    scale_b: TensorLike,
    bias: TensorLike | None = None,
    scale_result: TensorLike | None = None,
    out_dtype: dtypes.dtype | None = None,
    use_fast_accum: bool = False,
) -> bool:
    from thunder.core.devices import to_torch_device

    enable_scaled_mm: None | bool = get_compile_option("nv_enable_scaled_mm", "Enable nvFuser scaled_mm.")
    if not enable_scaled_mm:
        return False
    if out_dtype is None:
        return False
    # Only tensor-wise fp32 scales without fast accumulation are supported.
    if not (scale_a.numel == 1 and scale_b.numel == 1 and scale_result is None and not use_fast_accum):
        return False
    if not (
        a.dtype in dtypes.float_8bit_dtypes
        and b.dtype in dtypes.float_8bit_dtypes
        and scale_a.dtype == dtypes.float32
        and scale_b.dtype == dtypes.float32
    ):
        return False
    # fp8 matmul requires compute capability >= 9.0 (Hopper).
    if torch.cuda.get_device_capability(to_torch_device(a.device))[0] < 9:
        return False
    return True
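The path was opt-in via the `nv_enable_scaled_mm` compile option checked above. A hypothetical enabling sketch, assuming keyword arguments to `thunder.jit` are surfaced through `get_compile_option`:

```python
import thunder
import torch

def fn(a, b, scale_a, scale_b):
    return torch._scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.bfloat16)

# Without nv_enable_scaled_mm=True the checker above returns False and
# execution falls back to another executor (e.g. torch._scaled_mm itself).
jfn = thunder.jit(fn, nv_enable_scaled_mm=True)
```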


def _scaled_mm_meta(
    a: TensorProxy,
    b: TensorProxy,
    scale_a: TensorLike,
    scale_b: TensorLike,
    bias: TensorLike | None = None,
    scale_result: TensorLike | None = None,
    out_dtype: dtypes.dtype | None = None,
    use_fast_accum: bool = False,
) -> TensorLike:
    result_dtype = a.dtype if out_dtype is None else dtypes.to_dtype(out_dtype)
    return TensorProxy(
        like=a,
        shape=(a.shape[0], b.shape[1]),
        device=a.device,
        dtype=result_dtype,
    )


def _nv_scaled_mm(
    a: TensorProxy,
    b: TensorProxy,
    scale_a: TensorLike,
    scale_b: TensorLike,
    bias: TensorLike | None = None,
    scale_result: TensorLike | None = None,
    out_dtype: dtypes.dtype | None = None,
    use_fast_accum: bool = False,
    *,
    fd: FusionDefinition,
    lc_to_nv_map: dict,
) -> Any:
    nv_a = getnv(a, fd, lc_to_nv_map)
    nv_b = getnv(b, fd, lc_to_nv_map)
    nvscale_a = getnv(scale_a, fd, lc_to_nv_map)
    nvscale_b = getnv(scale_b, fd, lc_to_nv_map)

    # Decompose scaled_mm as matmul(a.float() * scale_a, b.float() * scale_b) (+ bias).
    cast_a = fd.ops.cast(nv_a, DataType.Float)
    cast_b = fd.ops.cast(nv_b, DataType.Float)
    scaled_cast_a = fd.ops.mul(cast_a, nvscale_a)
    scaled_cast_b = fd.ops.mul(cast_b, nvscale_b)

    output: Any
    if bias is not None:
        nvbias = getnv(bias, fd, lc_to_nv_map)
        output = fd.ops.linear(scaled_cast_a, scaled_cast_b, nvbias)
    else:
        output = fd.ops.matmul(scaled_cast_a, scaled_cast_b)
    return fd.ops.cast(output, lcdtype_to_nvdtype(dtypes.to_dtype(out_dtype)))


def _scaled_mm(
    a: TensorProxy,
    b: TensorProxy,
    scale_a: TensorLike,
    scale_b: TensorLike,
    bias: TensorLike | None = None,
    scale_result: TensorLike | None = None,
    out_dtype: dtypes.dtype | None = None,
    use_fast_accum: bool = False,
):
    return nv__scaled_mm(a, b, scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum)


nv__scaled_mm = ex.register_operator(
    "nv__scaled_mm",
    meta=_scaled_mm_meta,
    fn=_nv_scaled_mm,
)
register_supported(nv__scaled_mm.id, _nv_scaled_mm, None)


ex.register_supported(
    ltorch._scaled_mm,
    checker=_scaled_mm_check,
    execution_transform=_scaled_mm,
)
3 changes: 3 additions & 0 deletions thunder/torch/__init__.py
@@ -3588,6 +3588,9 @@ def matmul(a: TensorLike, b: TensorLike, /) -> TensorLike:
    return prims.matmul(a, b)


# TODO(crcrpar): Add nvFuser support of `matmul(a.float() * scale_a, b.float() * scale_b) + bias`.
# So far I haven't managed to get accurate results from the nvFuser region, as I noted in
# https://github.com/Lightning-AI/lightning-thunder/pull/1415/files#r1892875183
# reference: https://github.com/pytorch/pytorch/blob/6d4cd3e/torch/_meta_registrations.py#L5566
@torchsymbol(torch._scaled_mm, is_method=False)
def _scaled_mm(
    …
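A minimal eager-mode sketch of the decomposition the TODO above describes (a hypothetical helper, not code from this PR; the name and default `out_dtype` are assumptions):

```python
import torch

def scaled_mm_decomposition(a, b, scale_a, scale_b, bias=None, out_dtype=torch.bfloat16):
    # matmul(a.float() * scale_a, b.float() * scale_b) + bias, per the TODO above
    out = torch.matmul(a.float() * scale_a, b.float() * scale_b)
    if bias is not None:
        out = out + bias
    return out.to(out_dtype)
```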
