From 47eb3dc468ad33fe924696570aa5cae757f778b2 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Tue, 16 Apr 2024 21:17:09 +0000 Subject: [PATCH] enable nvfuser matmul --- thunder/tests/test_nvfuser.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/thunder/tests/test_nvfuser.py b/thunder/tests/test_nvfuser.py index 0df0cedd84..3b6bd6486b 100644 --- a/thunder/tests/test_nvfuser.py +++ b/thunder/tests/test_nvfuser.py @@ -850,3 +850,26 @@ def get_num_fusions(cfn): assert get_num_fusions(cfn_without_fusion) == 0 nvfuserex.set_fuel(thunder.extend.FUEL_LEVEL.UNLIMITED) + + +@instantiate(dtypes=(thunder.float16, thunder.bfloat16), devicetypes=(devices.DeviceType.CUDA,), executors=(nvFuserExecutor,)) +def test_matmul(executor, device: str, dtype: dtypes.dtype): + m, n, k = 128, 64, 32 + torch_dtype = ltorch.to_torch_dtype(dtype) + a = torch.randn((m, k), dtype=torch_dtype, device=device) + b = torch.randn((k, n), dtype=torch_dtype, device=device) + + def fn(a , b): + return a.matmul(b); + + compiled_func = thunder.compile( + fn, + executors_list=executor.executors_list(), + nv_enable_matmul=True, + ) + out = compiled_func(a, b) + + traces = thunder.last_traces(compiled_func) + fusions = examine.get_fusions(traces[-1]) + assert len(fusions) == 1 + assert torch.allclose(out, torch.matmul(a, b)) \ No newline at end of file