Skip to content

Commit

Permalink
Add if __name__ == "__main__" to softmax.py
Browse files Browse the repository at this point in the history
  • Loading branch information
voltjia committed Jan 11, 2025
1 parent ec7bc2d commit a8bfbd6
Showing 1 changed file with 44 additions and 45 deletions.
89 changes: 44 additions & 45 deletions softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,53 +71,52 @@ def triton_softmax(input):
return output


torch.manual_seed(0)
input = torch.randn(1823, 781, dtype=torch.float16, device="cuda")
ninetoothed_output = softmax(input)
torch_output = torch.softmax(input, axis=-1)
triton_output = triton_softmax(input)
print(ninetoothed_output)
print(torch_output)
print(triton_output)
if torch.allclose(ninetoothed_output, torch_output, atol=1e-5):
print("✅ NineToothed and PyTorch match.")
else:
print("❌ NineToothed and PyTorch differ.")
if torch.allclose(ninetoothed_output, triton_output):
print("✅ NineToothed and Triton match.")
else:
print("❌ NineToothed and Triton differ.")


@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["n"],
x_vals=[128 * i for i in range(2, 100)],
line_arg="provider",
line_vals=["ninetoothed", "torch", "triton"],
line_names=["NineToothed", "PyTorch", "Triton"],
styles=[("blue", "-"), ("green", "-"), ("orange", "-")],
ylabel="GB/s",
plot_name="softmax-performance",
args={"m": 4096},
if __name__ == "__main__":
torch.manual_seed(0)
input = torch.randn(1823, 781, dtype=torch.float16, device="cuda")
ninetoothed_output = softmax(input)
torch_output = torch.softmax(input, axis=-1)
triton_output = triton_softmax(input)
print(ninetoothed_output)
print(torch_output)
print(triton_output)
if torch.allclose(ninetoothed_output, torch_output, atol=1e-5):
print("✅ NineToothed and PyTorch match.")
else:
print("❌ NineToothed and PyTorch differ.")
if torch.allclose(ninetoothed_output, triton_output):
print("✅ NineToothed and Triton match.")
else:
print("❌ NineToothed and Triton differ.")

@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["n"],
x_vals=[128 * i for i in range(2, 100)],
line_arg="provider",
line_vals=["ninetoothed", "torch", "triton"],
line_names=["NineToothed", "PyTorch", "Triton"],
styles=[("blue", "-"), ("green", "-"), ("orange", "-")],
ylabel="GB/s",
plot_name="softmax-performance",
args={"m": 4096},
)
)
)
def benchmark(m, n, provider):
input = torch.randn(m, n, device="cuda", dtype=torch.float16)
stream = torch.cuda.Stream()
torch.cuda.set_stream(stream)
def benchmark(m, n, provider):
input = torch.randn(m, n, device="cuda", dtype=torch.float16)
stream = torch.cuda.Stream()
torch.cuda.set_stream(stream)

if provider == "ninetoothed":
ms = triton.testing.do_bench(lambda: softmax(input))
elif provider == "torch":
ms = triton.testing.do_bench(lambda: torch.softmax(input, axis=-1))
elif provider == "triton":
ms = triton.testing.do_bench(lambda: triton_softmax(input))
if provider == "ninetoothed":
ms = triton.testing.do_bench(lambda: softmax(input))
elif provider == "torch":
ms = triton.testing.do_bench(lambda: torch.softmax(input, axis=-1))
elif provider == "triton":
ms = triton.testing.do_bench(lambda: triton_softmax(input))

def gbps(ms):
return 2 * input.numel() * input.element_size() * 1e-6 / ms
def gbps(ms):
return 2 * input.numel() * input.element_size() * 1e-6 / ms

return gbps(ms)
return gbps(ms)


benchmark.run(show_plots=True, print_data=True, save_path=".")
benchmark.run(show_plots=True, print_data=True, save_path=".")

0 comments on commit a8bfbd6

Please sign in to comment.