diff --git a/add.py b/add.py index 06e244c..c0829e1 100644 --- a/add.py +++ b/add.py @@ -59,8 +59,9 @@ def grid(meta): torch.manual_seed(0) size = 98432 -lhs = torch.rand(size, device="cuda") -rhs = torch.rand(size, device="cuda") +dtype = torch.float16 +lhs = torch.rand(size, dtype=dtype, device="cuda") +rhs = torch.rand(size, dtype=dtype, device="cuda") ninetoothed_output = add(lhs, rhs) torch_output = lhs + rhs triton_output = triton_add(lhs, rhs) @@ -92,8 +93,8 @@ def grid(meta): ) ) def benchmark(size, provider): - lhs = torch.rand(size, device="cuda", dtype=torch.float32) - rhs = torch.rand(size, device="cuda", dtype=torch.float32) + lhs = torch.rand(size, device="cuda", dtype=torch.float16) + rhs = torch.rand(size, device="cuda", dtype=torch.float16) quantiles = [0.5, 0.2, 0.8] if provider == "ninetoothed": diff --git a/conv2d.py b/conv2d.py index f17062d..434ecea 100644 --- a/conv2d.py +++ b/conv2d.py @@ -200,8 +200,9 @@ def grid(meta): torch.manual_seed(0) n, c, h, w = 4, 3, 224, 224 k, _, r, s = 8, c, 3, 3 - input = torch.randn(n, c, h, w, device="cuda") - filter = torch.randn(k, c, r, s, device="cuda") + dtype = torch.float16 + input = torch.randn(n, c, h, w, dtype=dtype, device="cuda") + filter = torch.randn(k, c, r, s, dtype=dtype, device="cuda") ninetoothed_output = conv2d(input, filter) torch_output = F.conv2d(input, filter) triton_output = triton_conv2d(input, filter) @@ -233,8 +234,9 @@ def grid(meta): def benchmark(h, w, provider): n, c, _, _ = 64, 3, h, w k, _, r, s = 64, c, 3, 3 - input = torch.randn((n, c, h, w), device="cuda") - filter = torch.randn((k, c, r, s), device="cuda") + dtype = torch.float16 + input = torch.randn((n, c, h, w), dtype=dtype, device="cuda") + filter = torch.randn((k, c, r, s), dtype=dtype, device="cuda") if provider == "ninetoothed": ms = triton.testing.do_bench(lambda: conv2d(input, filter)) diff --git a/softmax.py b/softmax.py index d42f316..533a627 100644 --- a/softmax.py +++ b/softmax.py @@ -72,14 +72,14 @@ def triton_softmax(input): torch.manual_seed(0) -input = torch.randn(1823, 781, device="cuda") +input = torch.randn(1823, 781, dtype=torch.float16, device="cuda") ninetoothed_output = softmax(input) torch_output = torch.softmax(input, axis=-1) triton_output = triton_softmax(input) print(ninetoothed_output) print(torch_output) print(triton_output) -if torch.allclose(ninetoothed_output, torch_output): +if torch.allclose(ninetoothed_output, torch_output, atol=1e-5): print("✅ NineToothed and PyTorch match.") else: print("❌ NineToothed and PyTorch differ.") @@ -103,7 +103,7 @@ def triton_softmax(input): ) ) def benchmark(m, n, provider): - input = torch.randn(m, n, device="cuda", dtype=torch.float32) + input = torch.randn(m, n, device="cuda", dtype=torch.float16) stream = torch.cuda.Stream() torch.cuda.set_stream(stream)