Fix StreamK Kernel Bias #67
Changes from all commits
@@ -61,11 +61,15 @@
 # Whether to test with torch.compile
 USE_COMPILE = [False, True]
 
+# Whether to enable StreamK (vs. Persistent path)
+ENABLE_STREAMK = [False, True]
+
 
 @pytest.mark.parametrize("use_compile", USE_COMPILE)
+@pytest.mark.parametrize("enable_streamk", ENABLE_STREAMK)
 @pytest.mark.parametrize("m, n, k", STANDARD_DIMS + EDGE_CASE_DIMS)
 @pytest.mark.parametrize("dtype", DTYPES)
-def test_addmm_forward_correctness(m, n, k, dtype, use_compile):
+def test_addmm_forward_correctness(m, n, k, dtype, enable_streamk, use_compile):
     """Test that tritonblas.addmm forward pass matches torch.addmm."""
     torch.manual_seed(42)
 
@@ -78,7 +82,7 @@ def test_addmm_forward_correctness(m, n, k, dtype, use_compile):
         addmm_fn = torch.compile(tritonblas.addmm, fullgraph=True)
 
     # tritonblas result
-    result = addmm_fn(bias, a, b)
+    result = addmm_fn(bias, a, b, enable_streamk=enable_streamk)
 
     # torch reference
     expected = torch.addmm(bias, a, b)
@@ -88,9 +92,10 @@ def test_addmm_forward_correctness(m, n, k, dtype, use_compile):
 
 
 @pytest.mark.parametrize("use_compile", USE_COMPILE)
+@pytest.mark.parametrize("enable_streamk", ENABLE_STREAMK)
 @pytest.mark.parametrize("m, n, k", STANDARD_DIMS + EDGE_CASE_DIMS)
 @pytest.mark.parametrize("dtype", DTYPES)
-def test_addmm_backward_correctness(m, n, k, dtype, use_compile):
+def test_addmm_backward_correctness(m, n, k, dtype, enable_streamk, use_compile):
     """Test that tritonblas.addmm backward pass produces correct gradients."""
     torch.manual_seed(42)
 
@@ -109,7 +114,7 @@ def test_addmm_backward_correctness(m, n, k, dtype, use_compile):
         addmm_fn = torch.compile(tritonblas.addmm, fullgraph=True)
 
     # Forward pass
-    result = addmm_fn(bias, a, b)
+    result = addmm_fn(bias, a, b, enable_streamk=enable_streamk)
     result_ref = torch.addmm(bias_ref, a_ref, b_ref)
 
     # Backward pass with same upstream gradient
@@ -127,9 +132,10 @@ def test_addmm_backward_correctness(m, n, k, dtype, use_compile):
 
 
 @pytest.mark.parametrize("use_compile", USE_COMPILE)
+@pytest.mark.parametrize("enable_streamk", ENABLE_STREAMK)
 @pytest.mark.parametrize("m, n, k", SKINNY_DIMS)
 @pytest.mark.parametrize("dtype", DTYPES)
-def test_addmm_skinny_matrices(m, n, k, dtype, use_compile):
+def test_addmm_skinny_matrices(m, n, k, dtype, enable_streamk, use_compile):
     """Test addmm with skinny matrices (large K dimension)."""
     torch.manual_seed(42)
 
@@ -146,7 +152,7 @@ def test_addmm_skinny_matrices(m, n, k, dtype, use_compile):
         addmm_fn = torch.compile(tritonblas.addmm, fullgraph=True)
 
     # Forward
-    result = addmm_fn(bias, a, b)
+    result = addmm_fn(bias, a, b, enable_streamk=enable_streamk)
     result_ref = torch.addmm(bias_ref, a_ref, b_ref)
 
     torch.testing.assert_close(result, result_ref, atol=1e-1, rtol=1e-1)
@@ -161,7 +167,8 @@ def test_addmm_skinny_matrices(m, n, k, dtype, use_compile):
 
 
 @pytest.mark.parametrize("use_compile", USE_COMPILE)
-def test_addmm_inplace_with_grad_raises(use_compile):
+@pytest.mark.parametrize("enable_streamk", ENABLE_STREAMK)
+def test_addmm_inplace_with_grad_raises(enable_streamk, use_compile):
     """Test that addmm with out=... raises RuntimeError when autograd is enabled."""
     torch.manual_seed(42)
     m, n, k = 64, 64, 64
@@ -177,11 +184,12 @@ def test_addmm_inplace_with_grad_raises(use_compile):
         addmm_fn = torch.compile(tritonblas.addmm, fullgraph=True)
 
     with pytest.raises(RuntimeError, match="don't support automatic differentiation"):
-        addmm_fn(bias, a, b, out=out)
+        addmm_fn(bias, a, b, out=out, enable_streamk=enable_streamk)
 
 
 @pytest.mark.parametrize("use_compile", USE_COMPILE)
-def test_addmm_inplace_without_grad_works(use_compile):
+@pytest.mark.parametrize("enable_streamk", ENABLE_STREAMK)
+def test_addmm_inplace_without_grad_works(enable_streamk, use_compile):
     """Test that addmm with out=... works when autograd is disabled."""
     torch.manual_seed(42)
     m, n, k = 64, 64, 64
@@ -198,7 +206,7 @@ def test_addmm_inplace_without_grad_works(use_compile):
 
     # Should work with torch.no_grad()
     with torch.no_grad():
-        result = addmm_fn(bias, a, b, out=out)
+        result = addmm_fn(bias, a, b, out=out, enable_streamk=enable_streamk)
 
     # In-place path returns None (custom ops don't support aliasing)
     assert result is None, "in-place addmm should return None"
@@ -209,7 +217,8 @@ def test_addmm_inplace_without_grad_works(use_compile):
 
 
 @pytest.mark.parametrize("use_compile", USE_COMPILE)
-def test_addmm_inplace_output_correctness(use_compile):
+@pytest.mark.parametrize("enable_streamk", ENABLE_STREAMK)
+def test_addmm_inplace_output_correctness(enable_streamk, use_compile):
     """Test that addmm in-place mode produces correct results."""
     torch.manual_seed(42)
     m, n, k = 128, 256, 512
@@ -225,14 +234,15 @@ def test_addmm_inplace_output_correctness(use_compile):
         addmm_fn = torch.compile(tritonblas.addmm, fullgraph=True)
 
     with torch.no_grad():
-        addmm_fn(bias, a, b, out=out)
+        addmm_fn(bias, a, b, out=out, enable_streamk=enable_streamk)
 
     expected = torch.addmm(bias, a, b)
     torch.testing.assert_close(out, expected, atol=1e-1, rtol=1e-1)
 
 
 @pytest.mark.parametrize("use_compile", USE_COMPILE)
-def test_addmm_no_grad_tensors(use_compile):
+@pytest.mark.parametrize("enable_streamk", ENABLE_STREAMK)
+def test_addmm_no_grad_tensors(enable_streamk, use_compile):
     """Test addmm works when input tensors don't require grad."""
     torch.manual_seed(42)
     m, n, k = 64, 64, 64
@@ -246,14 +256,15 @@ def test_addmm_no_grad_tensors(use_compile):
     if use_compile:
         addmm_fn = torch.compile(tritonblas.addmm, fullgraph=True)
 
-    result = addmm_fn(bias, a, b)
+    result = addmm_fn(bias, a, b, enable_streamk=enable_streamk)
     expected = torch.addmm(bias, a, b)
 
     torch.testing.assert_close(result, expected, atol=1e-1, rtol=1e-1)
 
 
 @pytest.mark.parametrize("use_compile", USE_COMPILE)
-def test_addmm_partial_grad(use_compile):
+@pytest.mark.parametrize("enable_streamk", ENABLE_STREAMK)
+def test_addmm_partial_grad(enable_streamk, use_compile):
     """Test addmm when only some inputs require grad."""
     torch.manual_seed(42)
     m, n, k = 64, 64, 64
@@ -272,45 +283,10 @@ def test_addmm_partial_grad(use_compile):
     if use_compile:
         addmm_fn = torch.compile(tritonblas.addmm, fullgraph=True)
 
     result = addmm_fn(bias, a, b)
     result_ref = torch.addmm(bias_ref, a_ref, b_ref)
 
     result.sum().backward()
     result_ref.sum().backward()
 
     torch.testing.assert_close(a.grad, a_ref.grad, atol=1e-1, rtol=1e-1)
-
-
-@pytest.mark.parametrize("use_compile", USE_COMPILE)
-@pytest.mark.parametrize("enable_streamk", [False, True])
-def test_addmm_streamk_modes(enable_streamk, use_compile):
-    """Test addmm with different streamk settings."""
-    torch.manual_seed(42)
-    m, n, k = 256, 256, 256
-    dtype = torch.bfloat16
-
-    a = torch.randn(m, k, device='cuda', dtype=dtype, requires_grad=True)
-    b = torch.randn(k, n, device='cuda', dtype=dtype, requires_grad=True)
-    bias = torch.randn(n, device='cuda', dtype=dtype, requires_grad=True)
-
-    a_ref = a.detach().clone().requires_grad_(True)
-    b_ref = b.detach().clone().requires_grad_(True)
-    bias_ref = bias.detach().clone().requires_grad_(True)
-
-    addmm_fn = tritonblas.addmm
-    if use_compile:
-        addmm_fn = torch.compile(tritonblas.addmm, fullgraph=True)
-
-    # Forward
-    result = addmm_fn(bias, a, b, enable_streamk=enable_streamk)
-    result_ref = torch.addmm(bias_ref, a_ref, b_ref)
-
-    torch.testing.assert_close(result, result_ref, atol=1e-1, rtol=1e-1)
-
-    # Backward
-    result.sum().backward()
-    result_ref.sum().backward()
-
-    torch.testing.assert_close(a.grad, a_ref.grad, atol=1e-1, rtol=1e-1)
-    torch.testing.assert_close(b.grad, b_ref.grad, atol=1e-1, rtol=1e-1)
-    torch.testing.assert_close(bias.grad, bias_ref.grad, atol=1e-1, rtol=1e-1)
Comment on lines 286 to 292:
In this Stream-K section, bias is (now correctly) loaded earlier as a full BLOCK_SIZE_N vector, but this block re-loads bias twice (bias_left / bias_right) and also introduces % N plus a redundant mask (after the modulo, indices are already in range for N > 0). To reduce global memory traffic and simplify the indexing, consider reusing the already-loaded bias vector and slicing/splitting it into left/right halves for broadcasting, rather than recomputing indices and issuing additional tl.load operations.
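
A minimal Triton sketch of the single-load idea, for illustration only: it is not the tritonblas Stream-K epilogue, and the kernel name and parameters (_bias_epilogue, bias_ptr, stride_cm, BLOCK_SIZE_N, etc.) are placeholders. The bias row is loaded once with a single masked tl.load and broadcast across the tile; if the epilogue really needs left/right halves, they can be taken from this already-loaded vector instead of going back to global memory.

import triton
import triton.language as tl


@triton.jit
def _bias_epilogue(c_ptr, bias_ptr, M, N,
                   stride_cm, stride_cn,
                   BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr):
    # Hypothetical standalone epilogue: add a length-N bias row to one M x N tile of C.
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    mask_m = offs_m < M
    mask_n = offs_n < N

    # Single masked load of the full BLOCK_SIZE_N bias vector; no % N indexing
    # and no separate bias_left / bias_right loads.
    bias = tl.load(bias_ptr + offs_n, mask=mask_n, other=0.0)

    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    c_mask = mask_m[:, None] & mask_n[None, :]
    tile = tl.load(c_ptrs, mask=c_mask, other=0.0)

    # Broadcast the one loaded vector across the M dimension of the tile.
    tl.store(c_ptrs, tile + bias[None, :], mask=c_mask)

The point is simply that bias traffic stays at one load per tile; any left/right broadcasting the Stream-K path needs can index into the in-register bias values rather than issuing further tl.load calls.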