
Commit 5ea86fa

Add tests.
1 parent a7ad039 commit 5ea86fa

10 files changed: +183 −24 lines

python/shark_turbine/ops/iree.py

Lines changed: 2 additions & 4 deletions
@@ -43,8 +43,7 @@ def _emit_tensor_trace(kb: KernelBuilder, key: str, ts: list[Value]):
 
 @CustomOp.register(library=IREE_LIBRARY)
 class trace_tensor(CustomOp):
-    name = "trace_tensor"
-    signature = "(str trace_key, Tensor tensor) -> ()"
+    signature = "trace_tensor(str trace_key, Tensor tensor) -> ()"
 
     def select(self, ksel: KernelSelection):
         ksel.attr_str(0)
@@ -57,8 +56,7 @@ def generate(self, ksel: KernelSelection, kb: KernelBuilder):
 
 @CustomOp.register(library=IREE_LIBRARY)
 class trace_tensors(CustomOp):
-    name = "trace_tensors"
-    signature = "(str trace_key, Tensor[] tensors) -> ()"
+    signature = "trace_tensors(str trace_key, Tensor[] tensors) -> ()"
 
     def select(self, ksel: KernelSelection):
         ksel.attr_str(0)
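
For context, the convention this commit moves to puts the op name inside the schema string rather than in a separate `name` property. A minimal sketch of defining an op under the new convention, mirroring the pattern above (the library, op name, and method bodies below are illustrative, not part of this commit):

    from shark_turbine.runtime.op_reg import (
        CustomOp,
        KernelBuilder,
        KernelSelection,
        def_library,
    )

    # Hypothetical library purely for illustration; iree.py registers into IREE_LIBRARY.
    EXAMPLE_LIBRARY = def_library("example_ns")


    @CustomOp.register(library=EXAMPLE_LIBRARY)
    class log_tensor(CustomOp):
        # The op name ("log_tensor") is now parsed out of the signature itself.
        signature = "log_tensor(str trace_key, Tensor tensor) -> ()"

        def select(self, ksel: KernelSelection):
            ksel.attr_str(0)
            ksel.arg_tensor(1)

        def generate(self, ksel: KernelSelection, kb: KernelBuilder):
            kb.yield_results()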

python/shark_turbine/runtime/op_reg/base.py

Lines changed: 18 additions & 9 deletions
@@ -13,6 +13,7 @@
 from abc import ABC, abstractmethod, abstractproperty
 import functools
 import logging
+import re
 import textwrap
 
 import torch
@@ -132,8 +133,8 @@ def __init__(
         register_meta: bool,
         register_impl: bool,
     ):
-        name = self.name
-        fq_schema = f"{name}{self.signature}"
+        fq_schema = self.signature
+        name = _extract_name_from_signature(fq_schema)
         library.define(fq_schema)
         self.library = library
         self.cache_key_base = f"{library.ns}.{library.kind}::{name}"
@@ -156,17 +157,15 @@ def __init__(
         fq_name = f"{library.ns}.{name}"
         ALL_CUSTOM_OP_REGS[fq_name] = self
 
-    @abstractproperty
-    def name(self) -> str:
-        """Name of the operation."""
-        ...
-
     @abstractproperty
     def signature(self) -> str:
         """PyTorch function signature.
 
-        This excludes the name, which will come from the `name` property
-        and be prepended to make a full PyTorch schema.
+        This is in the normal PyTorch kernel registration form. For example:
+
+        ```
+        my_op(Tensor t) -> Tensor
+        ```
         """
         ...
 
@@ -772,3 +771,13 @@ def handler(*args):
         return eager_dispatch(ksel)
 
     return handler
+
+
+_SIGNATURE_NAME_PATTERN = re.compile(r"^([^(]+)")
+
+
+def _extract_name_from_signature(sig: str) -> str:
+    m = re.match(_SIGNATURE_NAME_PATTERN, sig)
+    if not m:
+        raise ValueError(f"Expected signature of form `name() -> (). Got: {sig}")
+    return m.group(1)
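
The extraction helper simply captures everything before the first `(` of the schema string. A standalone illustration of that behavior (it does not import the private helper; the function name here is made up):

    import re

    # Same pattern as _SIGNATURE_NAME_PATTERN above.
    SIGNATURE_NAME_PATTERN = re.compile(r"^([^(]+)")


    def extract_name(sig: str) -> str:
        m = SIGNATURE_NAME_PATTERN.match(sig)
        if not m:
            raise ValueError(f"Expected signature of form `name(...) -> ...`. Got: {sig}")
        return m.group(1)


    assert extract_name("trace_tensor(str trace_key, Tensor tensor) -> ()") == "trace_tensor"
    assert extract_name("my_op(Tensor t) -> Tensor") == "my_op"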

python/shark_turbine/transforms/general/custom_op_expansion.py

Lines changed: 13 additions & 3 deletions
@@ -43,8 +43,11 @@
 
 
 class ExpandCustomOpsPass(Pass):
-    def __init__(self, root_op: Operation):
+    def __init__(
+        self, root_op: Operation, reg: dict[str, CustomOp] = ALL_CUSTOM_OP_REGS
+    ):
         super().__init__(root_op)
+        self.reg = reg
         # Track pending deletions in a dict to preserve order and unique.
         self.ops_to_delete: dict[Operation, None] = {}
         self.type_converter = NativeTypeConverter(root_op.context)
@@ -76,7 +79,7 @@ def expand_func(self, func_op: Operation):
                 custom_op_name = StringAttr(op.attributes["name"]).value
                 if custom_op_name.startswith(name_prefix):
                     local_name = custom_op_name[len(name_prefix) :]
-                    custom_op_reg = ALL_CUSTOM_OP_REGS.get(local_name)
+                    custom_op_reg = self.reg.get(local_name)
                     if custom_op_reg is not None:
                         self.expand_custom_op(custom_op_reg, op)
 
@@ -172,7 +175,9 @@ def attr_str(self, arg: int) -> AttrArg:
         return desc
 
     def return_tensor(self, t: Tensor) -> TensorArg:
-        raise NotImplementedError("NYI: return_tensor")
+        desc = TensorArg(t)
+        self.result_descs.append(desc)
+        return desc
 
 
 def _get_constant_str_from_value(v: Value) -> str:
@@ -224,6 +229,7 @@ def __init__(
         )
         self.location = location
         self.torch_op = torch_op
+        self.type_converter = type_converter
 
     def yield_results(self, *results: Value):
         """Yields results of the kernel computation."""
@@ -234,5 +240,9 @@ def yield_results(self, *results: Value):
             torch_op_results
         ), f"Mismatched yield_results with custom op results"
         for new_result, old_result in zip(results, torch_op_results):
+            torch_type = old_result.type
+            new_result = self.type_converter.materialize_native_to_torch(
+                new_result, torch_type
+            )
             old_result.replace_all_uses_with(new_result)
         self.yielded = True
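
ExpandCustomOpsPass now takes an optional registry, defaulting to the global ALL_CUSTOM_OP_REGS, and matches `torch.operator` names (with the leading "torch." stripped) against the registry keys. A minimal sketch of driving the pass directly, assuming the relevant ops were registered first as the new PassTest does in setUpClass; the MLIR here mirrors the new testdata files:

    from shark_turbine.runtime.op_reg.base import ALL_CUSTOM_OP_REGS
    from shark_turbine.support.ir_imports import Context, Module
    from shark_turbine.transforms.general.custom_op_expansion import ExpandCustomOpsPass

    ASM = """
    builtin.module {
      func.func @forward(%arg0: !torch.vtensor<[97,8],f32>) -> !torch.vtensor<[97,8],f32> {
        %0 = torch.operator "torch.expand_custom_op_pass_test.identity_tensor"(%arg0) : (!torch.vtensor<[97,8],f32>) -> (!torch.vtensor<[97,8],f32>)
        return %0 : !torch.vtensor<[97,8],f32>
      }
    }
    """

    with Context():
        m = Module.parse(ASM)
        # Run against a narrowed registry; omitting the second argument falls back
        # to the process-wide ALL_CUSTOM_OP_REGS.
        reg = {
            k: v
            for k, v in ALL_CUSTOM_OP_REGS.items()
            if k.startswith("expand_custom_op_pass_test.")
        }
        ExpandCustomOpsPass(m.operation, reg).run()
        print(m)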

tests/dynamo/type_conversion_test.py

Lines changed: 6 additions & 2 deletions
@@ -24,15 +24,19 @@ def testPrimitives(self):
         self._compareNative("!torch.int", "i64")
         self._compareNative("!torch.float", "f64")
 
+    def testSigned(self):
+        self._compareNative("!torch.bool", "i1", signless=False)
+        self._compareNative("!torch.int", "si64", signless=False)
+
     def testValueTensors(self):
         self._compareNative("!torch.vtensor<[2, 2],f32>", "tensor<2x2xf32>")
         self._compareNative("!torch.vtensor<[?, ?],f32>", "tensor<?x?xf32>")
         self._compareNative("!torch.vtensor<[],f32>", "tensor<f32>")
 
-    def _compareNative(self, torch_str: str, native_str: str):
+    def _compareNative(self, torch_str: str, native_str: str, *, signless: bool = True):
         with self.conv._context:
             torch_type = IrType.parse(torch_str)
-            native_type = self.conv.torch_type_to_native(torch_type)
+            native_type = self.conv.torch_type_to_native(torch_type, signless=signless)
         self.assertEqual(str(native_type), native_str)
 
 
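
The new `signless` flag selects between signless (`i64`) and explicitly signed (`si64`) integer lowerings. A hypothetical extra case in the same style as the tests above, just to spell out the contrast:

    def testSignedVsSignless(self):
        # Default (signless=True): !torch.int lowers to a signless builtin i64.
        self._compareNative("!torch.int", "i64")
        # With signless=False the signedness is kept explicit.
        self._compareNative("!torch.int", "si64", signless=False)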

tests/runtime/op_reg/kernel_aot_test.py

Lines changed: 4 additions & 3 deletions
@@ -43,14 +43,15 @@ class KernelRegTest(unittest.TestCase):
     def testTrace(self):
         mlp = MLP()
         prog = aot.export(mlp, torch.empty(97, 8, dtype=torch.float32))
-        # print("ORIGINAL EXPORTED:")
-        # print(prog.print_readable())
 
         p = ExpandCustomOpsPass(prog.mlir_module)
         p.run()
 
         print("CUSTOM OP CONVERTED:")
-        print(prog.mlir_module)
+        module_asm = str(prog.mlir_module)
+        self.assertIn('flow.tensor.trace "LAYER0"', module_asm)
+        self.assertIn('flow.tensor.trace "LAYER1"', module_asm)
+        self.assertIn('flow.tensor.trace "LAYER3"', module_asm)
 
     def testEager(self):
         mlp = MLP()

tests/runtime/op_reg/kernel_reg_test.py

Lines changed: 2 additions & 3 deletions
@@ -15,11 +15,10 @@
 
 
 class KernelRegTest(unittest.TestCase):
-    def testSimple(self):
+    def testRegistrationDispatchAndCache(self):
         @CustomOp.register
         class identity(CustomOp):
-            name = "test_identity"
-            signature = "(Tensor self) -> Tensor"
+            signature = "test_identity(Tensor self) -> Tensor"
 
             def select(self, ksel: KernelSelection):
                 x = ksel.arg_tensor(0)
Lines changed: 112 additions & 0 deletions
@@ -0,0 +1,112 @@
+# Copyright 2024 Advanced Micro Devices, Inc
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import logging
+from pathlib import Path
+import torch
+import unittest
+
+from shark_turbine.transforms.general.custom_op_expansion import ExpandCustomOpsPass
+from shark_turbine.runtime.op_reg import (
+    def_library,
+    CustomOp,
+    KernelBuilder,
+    KernelSelection,
+)
+
+from shark_turbine.support.ir_imports import (
+    Context,
+    Module,
+)
+
+
+class PassTest(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.lib = def_library("expand_custom_op_pass_test")
+        CustomOp.register(library=cls.lib)(IdentityOp)
+        CustomOp.register(library=cls.lib)(PrintStringAttrOp)
+        CustomOp.register(library=cls.lib)(IntArgOp)
+
+    def testTensorArgReturn(self):
+        m = self.run_test_case("custom_op_simple.mlir")
+        m_asm = str(m)
+        self.assertNotIn("torch.operator", m_asm)
+        self.assertIn(
+            "%0 = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[97,8],f32> -> tensor<97x8xf32>",
+            m_asm,
+        )
+        self.assertIn(
+            "%1 = torch_c.from_builtin_tensor %0 : tensor<97x8xf32> -> !torch.vtensor<[97,8],f32>",
+            m_asm,
+        )
+        print(m_asm)
+
+    def testStringAttrArg(self):
+        global _TEST_STRING_ATTR
+        _TEST_STRING_ATTR = ""
+        m = self.run_test_case("custom_op_string_attr.mlir")
+        m_asm = str(m)
+        self.assertEqual(_TEST_STRING_ATTR, "TEST_VALUE")
+        self.assertNotIn("torch.operator", m_asm)
+        print(m_asm)
+
+    def testIntArg(self):
+        global _TEST_STRING_ATTR
+        _TEST_STRING_ATTR = ""
+        with self.assertRaisesRegex(NotImplementedError, "arg_int"):
+            self.run_test_case("custom_op_int_arg.mlir")
+
+    def run_test_case(self, file_name: str):
+        p = Path(__file__).resolve().parent / "testdata" / file_name
+        contents = p.read_text()
+        with Context() as ctx:
+            m = Module.parse(contents)
+            p = ExpandCustomOpsPass(m.operation)
+            p.run()
+            print(f"TEST CASE {file_name}:\n{m}")
+            m.operation.verify()
+            return m
+
+
+class IdentityOp(CustomOp):
+    signature = "identity_tensor(Tensor t) -> Tensor"
+
+    def select(self, ksel: KernelSelection):
+        x = ksel.arg_tensor(0)
+        ksel.return_tensor(x.t)
+
+    def generate(self, ksel: KernelSelection, kb: KernelBuilder):
+        kb.yield_results(kb.arg_bindings[0])
+
+
+class PrintStringAttrOp(CustomOp):
+    signature = "print_string_attr(str key) -> ()"
+
+    def select(self, ksel: KernelSelection):
+        ksel.attr_str(0)
+
+    def generate(self, ksel: KernelSelection, kb: KernelBuilder):
+        global _TEST_STRING_ATTR
+        _TEST_STRING_ATTR = str(ksel.arg_descs[0].v)
+        print("CAPTURED STRING ATTR:", _TEST_STRING_ATTR)
+        kb.yield_results()
+
+
+class IntArgOp(CustomOp):
+    signature = "int_arg(int t) -> ()"
+
+    def select(self, ksel: KernelSelection):
+        x = ksel.arg_int(0)
+        ksel.return_int()
+
+    def generate(self, ksel: KernelSelection, kb: KernelBuilder):
+        kb.yield_results(kb.arg_bindings[0])
+
+
+if __name__ == "__main__":
+    logging.basicConfig(level=logging.DEBUG)
+    unittest.main()
Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
+builtin.module {
+
+func.func @forward() {
+  %i = torch.constant.int 1000
+  torch.operator "torch.expand_custom_op_pass_test.int_arg"(%i) : (!torch.int) -> ()
+  return
+}
+
+}
Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
+builtin.module {
+
+func.func @forward(%arg0: !torch.vtensor<[97,8],f32>) -> !torch.vtensor<[97,8],f32> {
+  %0 = torch.operator "torch.expand_custom_op_pass_test.identity_tensor"(%arg0) : (!torch.vtensor<[97,8],f32>) -> (!torch.vtensor<[97,8],f32>)
+  return %0 : !torch.vtensor<[97,8],f32>
+}
+
+}
Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
+builtin.module {
+
+func.func @forward() {
+  %str = torch.constant.str "TEST_VALUE"
+  torch.operator "torch.expand_custom_op_pass_test.print_string_attr"(%str) : (!torch.str) -> ()
+  return
+}
+
+}
