Formatting cleanup for the nvFuser matmul support and its test: two blank
lines between top-level definitions (PEP 8), black-style wrapping of the long
@instantiate decorator, removal of a stray semicolon and bad spacing in the
test helper, and a trailing newline at end of file.

diff --git a/thunder/executors/nvfuserex_impl.py b/thunder/executors/nvfuserex_impl.py
index 8aad6dd116..942af4a09e 100644
--- a/thunder/executors/nvfuserex_impl.py
+++ b/thunder/executors/nvfuserex_impl.py
@@ -1948,6 +1948,7 @@ def var_mean(
 
 register_supported(PrimIDs.VAR_MEAN, var_mean, _var_mean_check)
 
+
 def _matmul_check(
     a: TensorProxy,
     b: TensorProxy,
@@ -1957,6 +1958,7 @@ def _matmul_check(
     enable_matmul = False
     return enable_matmul and is_supported_tensor(a) and is_supported_tensor(b)
 
+
 def matmul(
     a: TensorProxy,
     b: TensorProxy,
@@ -1968,6 +1970,7 @@ def matmul(
     nvb = getnv(b, fd, lc_to_nv_map)
     return fd.ops.matmul(nva, nvb)
 
+
 register_supported(PrimIDs.MATMUL, matmul, _matmul_check)
 
 
diff --git a/thunder/tests/test_nvfuser.py b/thunder/tests/test_nvfuser.py
index 3b6bd6486b..a190bc0d3a 100644
--- a/thunder/tests/test_nvfuser.py
+++ b/thunder/tests/test_nvfuser.py
@@ -852,15 +852,17 @@ def get_num_fusions(cfn):
 nvfuserex.set_fuel(thunder.extend.FUEL_LEVEL.UNLIMITED)
 
 
-@instantiate(dtypes=(thunder.float16, thunder.bfloat16), devicetypes=(devices.DeviceType.CUDA,), executors=(nvFuserExecutor,))
+@instantiate(
+    dtypes=(thunder.float16, thunder.bfloat16), devicetypes=(devices.DeviceType.CUDA,), executors=(nvFuserExecutor,)
+)
 def test_matmul(executor, device: str, dtype: dtypes.dtype):
     m, n, k = 128, 64, 32
     torch_dtype = ltorch.to_torch_dtype(dtype)
     a = torch.randn((m, k), dtype=torch_dtype, device=device)
     b = torch.randn((k, n), dtype=torch_dtype, device=device)
 
-    def fn(a , b):
-        return a.matmul(b);
+    def fn(a, b):
+        return a.matmul(b)
 
     compiled_func = thunder.compile(
         fn,
@@ -872,4 +874,4 @@ def fn(a , b):
     traces = thunder.last_traces(compiled_func)
     fusions = examine.get_fusions(traces[-1])
     assert len(fusions) == 1
-    assert torch.allclose(out, torch.matmul(a, b))
\ No newline at end of file
+    assert torch.allclose(out, torch.matmul(a, b))