Skip to content

Commit

Permalink
Fix bug with block pointer multi dim args (pytorch#120263)
Browse files Browse the repository at this point in the history
Summary:
Now we can parse statements like
```
%22 = tt.make_tensor_ptr %20, [%21, %c128_i64], [%c2048_i64, %c1_i64], [%1, %c0_i32]
```

Test Plan:
Added new test

```
buck2 test mode/opt //hammer/ops/tests/inductor:ragged_hstu_test
```
now passes again with optimizations enabled.

Differential Revision: D53975130

Pull Request resolved: pytorch#120263
Approved by: https://github.com/aakhundov, https://github.com/sijiac
  • Loading branch information
oulgen authored and pytorchmergebot committed Feb 21, 2024
1 parent 3cd6a21 commit eae025b
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 2 deletions.
11 changes: 11 additions & 0 deletions test/dynamo/test_triton_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -1352,6 +1352,7 @@ def fwd_kernel(

if HAS_CUDA and HAS_LARK:
t = torch.randn(4)
tt = torch.randn(4, 1)
tests = [
[
add_kernel,
Expand Down Expand Up @@ -1419,6 +1420,16 @@ def fwd_kernel(
},
["output_ptr"],
],
[
kernel_with_block_ptr_2d,
{
"x_ptr": tt,
"output_ptr": tt,
"n_elements": 4,
"BLOCK_SIZE": 4,
},
["output_ptr"],
],
[
add_kernel_with_import,
{
Expand Down
11 changes: 9 additions & 2 deletions torch/_higher_order_ops/triton_kernel_wrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def parse_ttir(ttir, kwargs):
| INTERMEDIATE_CONSTANT
| CONSTANT
| PARAM
| "[" arg "]"
| "[" args "]"
| arg_with_index
?arg_with_index: arg "#" DIGIT+
Expand Down Expand Up @@ -251,7 +251,14 @@ def parse_ttir(ttir, kwargs):
def convert(token):
if isinstance(token, lark.tree.Tree):
if token.data == "args":
return [convert(a) for a in token.children]
res = []
for a in token.children:
c = convert(a)
if isinstance(c, list):
res.extend(c)
else:
res.append(c)
return res
elif token.data in {"assign_lhs", "arg_with_index"}:
# Drop length/index qualifier
return convert(token.children[0])
Expand Down
34 changes: 34 additions & 0 deletions torch/testing/_internal/triton_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,40 @@ def add_kernel_with_block_ptr(
boundary_check=[0],
)

@triton.jit
def kernel_with_block_ptr_2d(
    x_ptr,
    output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    # Identity-copy kernel exercising multi-dimensional block pointers:
    # each program instance reads one (BLOCK_SIZE, 1) tile from x_ptr via
    # a 2D block pointer and writes it unchanged to output_ptr.
    pid = tl.program_id(axis=0)
    row_offset = pid * BLOCK_SIZE
    # Source and destination tiles share the same (n_elements, 1) layout;
    # only the base pointer differs.
    src_block = tl.make_block_ptr(
        base=x_ptr,
        shape=[n_elements, 1],
        strides=[1, 1],
        offsets=[row_offset, 0],
        block_shape=[BLOCK_SIZE, 1],
        order=[1, 0],
    )
    dst_block = tl.make_block_ptr(
        base=output_ptr,
        shape=[n_elements, 1],
        strides=[1, 1],
        offsets=[row_offset, 0],
        block_shape=[BLOCK_SIZE, 1],
        order=[1, 0],
    )
    # boundary_check on axis 0 masks out-of-range rows in the last tile.
    tile = tl.load(src_block, boundary_check=[0])
    tl.store(dst_block, tile, boundary_check=[0])

from triton.language import load, store

@triton.jit
Expand Down

0 comments on commit eae025b

Please sign in to comment.