File tree Expand file tree Collapse file tree 3 files changed +24
-4
lines changed Expand file tree Collapse file tree 3 files changed +24
-4
lines changed Original file line number Diff line number Diff line change @@ -156,6 +156,10 @@ def reshape(inputs: Sequence["Register"]) -> "Register":
156
156
...
157
157
158
158
159
+ def round_even (src : "Register" ) -> "Register" :
160
+ ...
161
+
162
+
159
163
def define_op (op_name : str ) -> Callable [[T ], T ]:
160
164
def decorator (cls : T ) -> T :
161
165
cls .tkw_op_name = op_name
@@ -704,6 +708,7 @@ def infer_type(self):
704
708
@define_interface_op ("exp2" )
705
709
@define_interface_op ("reciprocal" )
706
710
@define_interface_op ("abs" )
711
+ @define_interface_op ("round_even" )
707
712
@define_py_op (operator .neg )
708
713
@dataclass
709
714
class UnaryPyOp (CustomOp , ABC ):
Original file line number Diff line number Diff line change 73
73
cast ,
74
74
permute ,
75
75
reshape ,
76
+ round_even ,
76
77
)
77
78
from ..lang .wave_types import IndexMapping , IndexSymbol
78
79
from ..compiler .base import CodegenError , ValidationError , NDEBUG
@@ -1197,6 +1198,16 @@ def handle_abs(source: Value) -> OpResult:
1197
1198
return abs
1198
1199
1199
1200
1201
+ @handle_unary_op (round_even )
1202
+ def handle_round_even (source : Value ) -> OpResult :
1203
+ element_type = get_type_or_element_type (source .type )
1204
+ if _is_float_type (element_type ):
1205
+ round_even = math_d .roundeven (source )
1206
+ else :
1207
+ raise ValidationError (f"Found unhandled operand type for abs: { element_type } " )
1208
+ return round_even
1209
+
1210
+
1200
1211
###############################################################################
1201
1212
# Control Flow ops
1202
1213
###############################################################################
Original file line number Diff line number Diff line change @@ -725,6 +725,7 @@ def test(
725
725
res = tkw .reciprocal (res )
726
726
res = tkw .abs (res )
727
727
res_b = tkw .abs (b_reg )
728
+ res = tkw .round_even (res )
728
729
tkw .write (res , a , elements_per_thread = 4 )
729
730
tkw .write (res_b , b , elements_per_thread = 4 )
730
731
@@ -740,12 +741,15 @@ def test(
740
741
# CHECK: %[[EXP2:.+]] = math.exp2 %[[NEG]]
741
742
742
743
# Testing reciprocal
743
- # %[[ONES:.+]] = arith.constant dense<1.000000e+00> : vector<4xf16>
744
- # %[[RECIPROCAL:.+]] = arith.divf %[[ONES]], %[[EXP2]] : vector<4xf16>
744
+ # CHECK: %[[ONES:.+]] = arith.constant dense<1.000000e+00> : vector<4xf16>
745
+ # CHECK: %[[RECIPROCAL:.+]] = arith.divf %[[ONES]], %[[EXP2]] : vector<4xf16>
745
746
746
747
# Testing abs
747
- # %[[ABSF:.+]] = math.absf %[[RECIPROCAL]]
748
- # %[[ABSI:.+]] = math.absi
748
+ # CHECK: %[[ABSF:.+]] = math.absf %[[RECIPROCAL]]
749
+ # CHECK: %[[ABSI:.+]] = math.absi
750
+
751
+ # Testing round_even
752
+ # CHECK: %[[ROUNDEVEN:.+]] = math.roundeven %[[ABSF]]
749
753
750
754
751
755
@run_test
You can’t perform that action at this time.
0 commit comments