diff --git a/thunder/executors/nvfuserex_impl.py b/thunder/executors/nvfuserex_impl.py
index 67ec6553d4..593f637e92 100644
--- a/thunder/executors/nvfuserex_impl.py
+++ b/thunder/executors/nvfuserex_impl.py
@@ -2579,6 +2579,8 @@ def _scaled_matmul_check(
     out_dtype: dtypes.dtype | None = None,
     use_fast_accum: bool = False,
 ) -> bool:
+    from thunder.core.devices import to_torch_device
+
     enable_matmul: None | bool = get_compile_option("nv_enable_matmul", "Enable nvFuser matmul.")
     enable_linear: None | bool = get_compile_option("nv_enable_linear", "Enable nvFuser linear.")
     if not (enable_matmul or enable_linear):
@@ -2592,7 +2594,8 @@ def _scaled_matmul_check(
         and scale_b.dtype == dtypes.float32
     ):
         return False
-
+    if torch.cuda.get_device_capability(to_torch_device(a.device))[0] < 9:
+        return False
     return True
 
 
@@ -2617,19 +2620,21 @@ def _scaled_matmul(
     if out_dtype is None:
         out_dtype = a.dtype
-    out_nvdtype = lcdtype_to_nvdtype(out_dtype)
+    nvout_dtype = lcdtype_to_nvdtype(out_dtype)
     cast_nva = fd.ops.cast(nva, DataType.Float)
     cast_nvb = fd.ops.cast(nvb, DataType.Float)
-    scaled_cast_nva = fd.ops.div(cast_nva, scale_a)
-    scaled_cast_nvb = fd.ops.div(cast_nvb, scale_b)
+    scaled_cast_nva = fd.ops.mul(cast_nva, nvscale_a)
+    scaled_cast_nvb = fd.ops.mul(cast_nvb, nvscale_b)
+    output: Any
     if bias is not None:
         nvbias = getnv(bias, fd, lc_to_nv_map)
-        return fd.ops.linear(scaled_cast_nva, scaled_cast_nvb, nvbias)
+        output = fd.ops.linear(scaled_cast_nva, scaled_cast_nvb, nvbias)
     else:
-        return fd.ops.matmul(scaled_cast_nva, scaled_cast_nvb)
+        output = fd.ops.matmul(scaled_cast_nva, scaled_cast_nvb)
+    return fd.ops.cast(output, nvout_dtype)
 
 
 register_supported(ltorch._scaled_mm, _scaled_matmul, _scaled_matmul_check)
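
For context, here is a minimal sketch of how the changed path could be exercised end to end. It assumes that the `nv_enable_matmul` compile option (queried via `get_compile_option` in the diff) can be passed as a keyword argument to `thunder.jit`, and that the host GPU is Hopper-class (SM 9.0+), which the new capability check requires; the shapes, dtypes, and scale values are illustrative only, not taken from the PR.

```python
import torch
import thunder

def fn(a, b, scale_a, scale_b):
    # torch._scaled_mm multiplies each operand by its dequantization scale,
    # which is what the div -> mul change above mirrors on the nvFuser side.
    return torch._scaled_mm(a, b, scale_a=scale_a, scale_b=scale_b, out_dtype=torch.bfloat16)

a = torch.randn(16, 32, device="cuda").to(torch.float8_e4m3fn)
# torch._scaled_mm expects the second operand in column-major layout.
b = torch.randn(64, 32, device="cuda").to(torch.float8_e4m3fn).t()
scale_a = torch.tensor(1.0, device="cuda")  # tensor-wise float32 scales
scale_b = torch.tensor(1.0, device="cuda")

jfn = thunder.jit(fn, nv_enable_matmul=True)  # assumed way to enable the nvFuser matmul path
out = jfn(a, b, scale_a, scale_b)  # bfloat16 result produced by the final cast added above
```

The trailing `fd.ops.cast(output, nvout_dtype)` matters in this scenario: without it the fusion would return the float32 intermediate rather than the requested `out_dtype`.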