diff --git a/src/gt4py/backend/numpy_backend.py b/src/gt4py/backend/numpy_backend.py index 4b6c86d5e4..27aa0fd327 100644 --- a/src/gt4py/backend/numpy_backend.py +++ b/src/gt4py/backend/numpy_backend.py @@ -27,6 +27,7 @@ def __init__(self, *args, interval_k_start_name, interval_k_end_name, **kwargs): super().__init__(*args, **kwargs) self.interval_k_start_name = interval_k_start_name self.interval_k_end_name = interval_k_end_name + self.conditions_depth = 0 def _make_field_origin(self, name: str, origin=None): if origin is None: @@ -174,9 +175,8 @@ def visit_StencilImplementation(self, node: gt_ir.StencilImplementation): def visit_TernaryOpExpr(self, node: gt_ir.TernaryOpExpr): then_fmt = "({})" if isinstance(node.then_expr, gt_ir.CompositeExpr) else "{}" else_fmt = "({})" if isinstance(node.else_expr, gt_ir.CompositeExpr) else "{}" - # source = "np.vectorize(lambda cond, then_expr, else_expr:then_expr if cond else else_expr)({condition}, {then_expr}, {else_expr})".format( - source = "{np}.choose({condition},[{else_expr}, {then_expr}], out={np}.empty({np}.max(({np}.asanyarray({condition}).shape,{np}.asanyarray({then_expr}).shape,{np}.asanyarray({else_expr}).shape),axis=0), dtype={np}.{dtype}))".format( + source = "vectorized_ternary_op(condition={condition}, then_expr={then_expr}, else_expr={else_expr}, dtype={np}.{dtype})".format( condition=self.visit(node.condition), then_expr=then_fmt.format(self.visit(node.then_expr)), else_expr=else_fmt.format(self.visit(node.else_expr)), @@ -187,7 +187,59 @@ def visit_TernaryOpExpr(self, node: gt_ir.TernaryOpExpr): return source def visit_If(self, node: gt_ir.If): - raise NotImplementedError("The numpy backend does not support runtime if statements.") + sources = [] + self.conditions_depth += 1 + sources.append( + "__condition_{level} = {condition}".format( + level=self.conditions_depth, condition=self.visit(node.condition) + ) + ) + + stmts = [ + *[(True, stmt) for stmt in node.main_body.stmts], + *[(False, stmt) for stmt in node.else_body.stmts], + ] + + for is_if, stmt in stmts: + + if isinstance(stmt, gt_ir.Assign): + condition = ( + ( + "{np}.logical_and(".format(np=self.numpy_prefix) + + ", ".join( + [ + "__condition_{level}".format(level=i + 1) + for i in range(self.conditions_depth) + ] + ) + + ")" + ) + if self.conditions_depth > 1 + else "__condition_1" + ) + + target = self.visit(stmt.target) + value = self.visit(stmt.value) + sources.append( + "{target} = vectorized_ternary_op(condition={condition}, then_expr={then_expr}, else_expr={else_expr}, dtype={np}.{dtype})".format( + condition=condition, + target=target, + then_expr=value if is_if else target, + else_expr=target if is_if else value, + dtype=stmt.target.data_type.dtype.name, + np=self.numpy_prefix, + ) + ) + else: + stmt_sources = self.visit(stmt) + if isinstance(stmt_sources, list): + sources.extend(stmt_sources) + else: + sources.append(stmt_sources) + + self.conditions_depth -= 1 + # return "\n".join(sources) + return sources class NumPyGenerator(gt_backend.BaseGenerator): @@ -206,6 +258,27 @@ def __init__(self, backend_class, options): interval_k_end_name="interval_k_end", ) + def generate_module_members(self): + source = """ +def vectorized_ternary_op(*, condition, then_expr, else_expr, dtype): + return np.choose( + condition, + [else_expr, then_expr], + out=np.empty( + np.max( + ( + np.asanyarray(condition).shape, + np.asanyarray(then_expr).shape, + np.asanyarray(else_expr).shape, + ), + axis=0, + ), + dtype=dtype, + ), + ) +""" + return source + def generate_implementation(self): sources = gt_text.TextBlock(indent_size=self.TEMPLATE_INDENT_SIZE) self.source_generator(self.performance_ir, sources) diff --git a/src/gt4py/backend/python_generator.py b/src/gt4py/backend/python_generator.py index 4c442a8b5a..1adf97f6b9 100644 --- a/src/gt4py/backend/python_generator.py +++ b/src/gt4py/backend/python_generator.py @@ -280,4 +280,3 @@ def visit_StencilImplementation(self, node: gt_ir.StencilImplementation): for stage in group.stages: self.visit(stage, iteration_order=multi_stage.iteration_order) self.sources.append("") - diff --git a/tests/test_integration/test_suites.py b/tests/test_integration/test_suites.py index 1ac1873d0a..41763f5940 100644 --- a/tests/test_integration/test_suites.py +++ b/tests/test_integration/test_suites.py @@ -358,3 +358,103 @@ def validation(u, diffusion, *, weight, domain, origin, **kwargs): diffusion[...] = u[2:-2, 2:-2, :] - weight * ( flux_i[1:, :, :] - flux_i[:-1, :, :] + flux_j[:, 1:, :] - flux_j[:, :-1, :] ) + + +class TestRuntimeIfFlat(gt_testing.StencilTestSuite): + """Tests runtime ifs. + """ + + dtypes = (np.float_,) + domain_range = [(1, 15), (1, 15), (1, 15)] + backends = ["debug", "numpy", "gtx86"] + symbols = dict( + outfield=gt_testing.field(in_range=(-10, 10), boundary=[(0, 0), (0, 0), (0, 0)]), + ) + + def definition(outfield): + + with computation(PARALLEL), interval(...): + + if 1: + outfield = 1 + else: + outfield = 2 + + def validation(outfield, *, domain, origin, **kwargs): + outfield[...] = 1 + + +class TestRuntimeIfNested(gt_testing.StencilTestSuite): + """Tests nested runtime ifs + """ + + dtypes = (np.float_,) + domain_range = [(1, 15), (1, 15), (1, 15)] + backends = ["debug", "numpy", "gtx86"] + symbols = dict( + outfield=gt_testing.field(in_range=(-10, 10), boundary=[(0, 0), (0, 0), (0, 0)]), + ) + + def definition(outfield): + + with computation(PARALLEL), interval(...): + + if 1: + if 0: + outfield = 1 + else: + outfield = 2 + else: + outfield = 3 + + def validation(outfield, *, domain, origin, **kwargs): + outfield[...] = 2 + + +class TestRuntimeIfNestedDataDependent(gt_testing.StencilTestSuite): + """Tests nested runtime ifs, where not the same branch is used for all data points. + """ + + dtypes = (np.float_,) + domain_range = [(1, 15), (1, 15), (1, 15)] + backends = ["debug", "numpy", "gtx86"] + symbols = dict( + infield=gt_testing.field(in_range=(-10, 10), boundary=[(0, 0), (0, 0), (0, 0)]), + outfield=gt_testing.field(in_range=(-10, 10), boundary=[(0, 0), (0, 0), (0, 0)]), + ) + + def definition(infield, outfield): + + with computation(PARALLEL), interval(...): + + if infield > 0: + if infield < 0.5: + outfield = 1 + else: + outfield = 2 + else: + outfield = 3 + + def validation(infield, outfield, *, domain, origin, **kwargs): + outfield[...] = 3 + outfield[...] = np.choose(infield > 0, [outfield, 2]) + outfield[...] = np.choose(np.logical_and(infield > 0, infield < 0.5), [outfield, 1]) + + +class TestTernaryOp(gt_testing.StencilTestSuite): + + dtypes = (np.float_,) + domain_range = [(1, 15), (2, 15), (1, 15)] + backends = ["numpy", "gtx86", "gtmc"] + symbols = dict( + infield=gt_testing.field(in_range=(-10, 10), boundary=[(0, 0), (0, 1), (0, 0)]), + outfield=gt_testing.field(in_range=(-10, 10), boundary=[(0, 0), (0, 0), (0, 0)]), + ) + + def definition(infield, outfield): + + with computation(PARALLEL), interval(...): + outfield = (infield > 0.0) * infield + (infield <= 0.0) * (-infield[0, 1, 0]) + + def validation(infield, outfield, *, domain, origin, **kwargs): + outfield[...] = np.choose(infield[:, :-1, :] > 0, [-infield[:, 1:, :], infield[:, :-1, :]])