Skip to content

Commit 6b02ab5

Browse files
committed
[TKW] Add support for tkw.round_even
Signed-off-by: Ege Beysel <beysel@roofline.ai>
1 parent d759cb5 commit 6b02ab5

File tree

3 files changed

+24
-4
lines changed

3 files changed

+24
-4
lines changed

iree/turbine/kernel/ops/wave_ops.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,10 @@ def reshape(inputs: Sequence["Register"]) -> "Register":
156156
...
157157

158158

159+
def round_even(src: "Register") -> "Register":
160+
...
161+
162+
159163
def define_op(op_name: str) -> Callable[[T], T]:
160164
def decorator(cls: T) -> T:
161165
cls.tkw_op_name = op_name
@@ -704,6 +708,7 @@ def infer_type(self):
704708
@define_interface_op("exp2")
705709
@define_interface_op("reciprocal")
706710
@define_interface_op("abs")
711+
@define_interface_op("round_even")
707712
@define_py_op(operator.neg)
708713
@dataclass
709714
class UnaryPyOp(CustomOp, ABC):

iree/turbine/kernel/wave/codegen.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@
7373
cast,
7474
permute,
7575
reshape,
76+
round_even,
7677
)
7778
from ..lang.wave_types import IndexMapping, IndexSymbol
7879
from ..compiler.base import CodegenError, ValidationError, NDEBUG
@@ -1197,6 +1198,16 @@ def handle_abs(source: Value) -> OpResult:
11971198
return abs
11981199

11991200

1201+
@handle_unary_op(round_even)
1202+
def handle_round_even(source: Value) -> OpResult:
1203+
element_type = get_type_or_element_type(source.type)
1204+
if _is_float_type(element_type):
1205+
round_even = math_d.roundeven(source)
1206+
else:
1207+
raise ValidationError(f"Found unhandled operand type for abs: {element_type}")
1208+
return round_even
1209+
1210+
12001211
###############################################################################
12011212
# Control Flow ops
12021213
###############################################################################

lit_tests/kernel/wave/codegen.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -725,6 +725,7 @@ def test(
725725
res = tkw.reciprocal(res)
726726
res = tkw.abs(res)
727727
res_b = tkw.abs(b_reg)
728+
res = tkw.round_even(res)
728729
tkw.write(res, a, elements_per_thread=4)
729730
tkw.write(res_b, b, elements_per_thread=4)
730731

@@ -740,12 +741,15 @@ def test(
740741
# CHECK: %[[EXP2:.+]] = math.exp2 %[[NEG]]
741742

742743
# Testing reciprocal
743-
# %[[ONES:.+]] = arith.constant dense<1.000000e+00> : vector<4xf16>
744-
# %[[RECIPROCAL:.+]] = arith.divf %[[ONES]], %[[EXP2]] : vector<4xf16>
744+
# CHECK: %[[ONES:.+]] = arith.constant dense<1.000000e+00> : vector<4xf16>
745+
# CHECK: %[[RECIPROCAL:.+]] = arith.divf %[[ONES]], %[[EXP2]] : vector<4xf16>
745746

746747
# Testing abs
747-
# %[[ABSF:.+]] = math.absf %[[RECIPROCAL]]
748-
# %[[ABSI:.+]] = math.absi
748+
# CHECK: %[[ABSF:.+]] = math.absf %[[RECIPROCAL]]
749+
# CHECK: %[[ABSI:.+]] = math.absi
750+
751+
# Testing round_even
752+
# CHECK: %[[ROUNDEVEN:.+]] = math.roundeven %[[ABSF]]
749753

750754

751755
@run_test

0 commit comments

Comments
 (0)