fix check and add missing cast
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
crcrpar committed Dec 16, 2024
1 parent 21f803e commit c026ee4
Showing 1 changed file with 11 additions and 6 deletions.
thunder/executors/nvfuserex_impl.py

@@ -2579,6 +2579,8 @@ def _scaled_matmul_check(
     out_dtype: dtypes.dtype | None = None,
     use_fast_accum: bool = False,
 ) -> bool:
+    from thunder.core.devices import to_torch_device
+
     enable_matmul: None | bool = get_compile_option("nv_enable_matmul", "Enable nvFuser matmul.")
     enable_linear: None | bool = get_compile_option("nv_enable_linear", "Enable nvFuser linear.")
     if not (enable_matmul or enable_linear):
@@ -2592,7 +2594,8 @@ def _scaled_matmul_check(
         and scale_b.dtype == dtypes.float32
     ):
         return False
-
+    if torch.cuda.get_device_capability(to_torch_device(a.device))[0] < 9:
+        return False
     return True

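The new guard claims torch._scaled_mm for nvFuser only on devices whose CUDA compute capability has major version 9 or higher (Hopper-class GPUs such as the H100); on older devices the check now returns False and the claim falls back to another executor. A minimal sketch of the same probe, assuming a CUDA device is visible (illustrative, not from the commit):

    import torch

    def passes_scaled_mm_gate() -> bool:
        # get_device_capability returns (major, minor), e.g. (9, 0) on H100
        major, _minor = torch.cuda.get_device_capability()
        return major >= 9
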
@@ -2617,19 +2620,21 @@ def _scaled_matmul(
     if out_dtype is None:
         out_dtype = a.dtype
 
-    out_nvdtype = lcdtype_to_nvdtype(out_dtype)
+    nvout_dtype = lcdtype_to_nvdtype(out_dtype)
 
     cast_nva = fd.ops.cast(nva, DataType.Float)
     cast_nvb = fd.ops.cast(nvb, DataType.Float)
 
-    scaled_cast_nva = fd.ops.div(cast_nva, scale_a)
-    scaled_cast_nvb = fd.ops.div(cast_nvb, scale_b)
+    scaled_cast_nva = fd.ops.mul(cast_nva, nvscale_a)
+    scaled_cast_nvb = fd.ops.mul(cast_nvb, nvscale_b)
 
+    output: Any
     if bias is not None:
         nvbias = getnv(bias, fd, lc_to_nv_map)
-        return fd.ops.linear(scaled_cast_nva, scaled_cast_nvb, nvbias)
+        output = fd.ops.linear(scaled_cast_nva, scaled_cast_nvb, nvbias)
     else:
-        return fd.ops.matmul(scaled_cast_nva, scaled_cast_nvb)
+        output = fd.ops.matmul(scaled_cast_nva, scaled_cast_nvb)
+    return fd.ops.cast(output, nvout_dtype)
 
 
 register_supported(ltorch._scaled_mm, _scaled_matmul, _scaled_matmul_check)
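
For reference, the fixed translation follows the dequantization semantics of torch._scaled_mm: upcast the FP8 inputs to float32, multiply (not divide) by the scales, run the matmul, then cast back to the requested out_dtype, which is the previously missing step that left results in float32. A plain-PyTorch sketch of the intended numerics (an illustration of the semantics, not the nvFuser implementation):

    import torch

    def scaled_matmul_reference(a, b, scale_a, scale_b, bias=None, out_dtype=None):
        out_dtype = out_dtype if out_dtype is not None else a.dtype
        fa = a.to(torch.float32) * scale_a  # mul: scales dequantize FP8 values
        fb = b.to(torch.float32) * scale_b
        out = fa @ fb
        if bias is not None:
            out = out + bias
        return out.to(out_dtype)  # the cast this commit adds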

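A hedged end-to-end sketch of exercising this path through Thunder. It assumes a Hopper GPU and a recent PyTorch; it also assumes that keyword arguments to thunder.jit are the compile options read by get_compile_option above, and that torch._scaled_mm takes the scales positionally (its signature has changed across PyTorch releases). torch._scaled_mm additionally requires a row-major first operand and a column-major second operand:

    import torch
    import thunder

    def fn(a, b, scale_a, scale_b):
        return torch._scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.bfloat16)

    a = torch.randn(16, 32, device="cuda").to(torch.float8_e4m3fn)
    b = torch.randn(16, 32, device="cuda").to(torch.float8_e4m3fn).t()  # column-major
    scale_a = torch.tensor(1.0, device="cuda")  # float32 scalar scales
    scale_b = torch.tensor(1.0, device="cuda")

    jfn = thunder.jit(fn, nv_enable_matmul=True)
    out = jfn(a, b, scale_a, scale_b)  # bfloat16 output, thanks to the added cast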