From 4f271fb2513bc5ed7344838af13308a9450c9c68 Mon Sep 17 00:00:00 2001 From: Linus Groner Date: Tue, 3 Dec 2019 15:13:12 +0100 Subject: [PATCH 1/3] Implemented runtime conditionals for numpy backend --- src/gt4py/backend/numpy_backend.py | 91 ++++++++++++++++++++++++++- tests/test_integration/test_suites.py | 81 ++++++++++++++++++++++++ 2 files changed, 169 insertions(+), 3 deletions(-) diff --git a/src/gt4py/backend/numpy_backend.py b/src/gt4py/backend/numpy_backend.py index 4b6c86d5e4..229ba12cf8 100644 --- a/src/gt4py/backend/numpy_backend.py +++ b/src/gt4py/backend/numpy_backend.py @@ -27,6 +27,7 @@ def __init__(self, *args, interval_k_start_name, interval_k_end_name, **kwargs): super().__init__(*args, **kwargs) self.interval_k_start_name = interval_k_start_name self.interval_k_end_name = interval_k_end_name + self.conditions_depth = 0 def _make_field_origin(self, name: str, origin=None): if origin is None: @@ -174,9 +175,8 @@ def visit_StencilImplementation(self, node: gt_ir.StencilImplementation): def visit_TernaryOpExpr(self, node: gt_ir.TernaryOpExpr): then_fmt = "({})" if isinstance(node.then_expr, gt_ir.CompositeExpr) else "{}" else_fmt = "({})" if isinstance(node.else_expr, gt_ir.CompositeExpr) else "{}" - # source = "np.vectorize(lambda cond, then_expr, else_expr:then_expr if cond else else_expr)({condition}, {then_expr}, {else_expr})".format( - source = "{np}.choose({condition},[{else_expr}, {then_expr}], out={np}.empty({np}.max(({np}.asanyarray({condition}).shape,{np}.asanyarray({then_expr}).shape,{np}.asanyarray({else_expr}).shape),axis=0), dtype={np}.{dtype}))".format( + source = "VectorizedTernaryOp(condition={condition}, then_expr={then_expr}, else_expr={else_expr}, dtype={np}.{dtype})".format( condition=self.visit(node.condition), then_expr=then_fmt.format(self.visit(node.then_expr)), else_expr=else_fmt.format(self.visit(node.else_expr)), @@ -187,7 +187,71 @@ def visit_TernaryOpExpr(self, node: gt_ir.TernaryOpExpr): return source def visit_If(self, node: gt_ir.If): - raise NotImplementedError("The numpy backend does not support runtime if statements.") + sources = [] + self.conditions_depth += 1 + sources.append( + "__condition_{level} = {condition}".format( + level=self.conditions_depth, condition=self.visit(node.condition) + ) + ) + + for stmt in node.main_body.stmts: + if isinstance(stmt, gt_ir.Assign): + sources.append( + "{target} = VectorizedTernaryOp(condition={condition}, then_expr={then_expr}, else_expr={target}, dtype={np}.{dtype})".format( + condition="{np}.logical_and(".format(np=self.numpy_prefix) + + ", ".join( + [ + "__condition_{level}".format(level=i + 1) + for i in range(self.conditions_depth) + ] + ) + + ")" + if self.conditions_depth > 1 + else "__condition_1", + target=self.visit(stmt.target), + then_expr=self.visit(stmt.value), + dtype=stmt.target.data_type.dtype.name, + np=self.numpy_prefix, + ) + ) + else: + stmt_sources = self.visit(stmt) + if isinstance(stmt_sources, list): + sources.extend(stmt_sources) + else: + sources.append(stmt_sources) + + for stmt in node.else_body.stmts: + if isinstance(stmt, gt_ir.Assign): + sources.append( + "{target} = VectorizedTernaryOp(condition={np}.logical_not({condition}), then_expr={then_expr}, else_expr={target}, dtype={np}.{dtype})".format( + condition="{np}.logical_and(".format(np=self.numpy_prefix) + + ", ".join( + [ + "__condition_{level}".format(level=i + 1) + for i in range(self.conditions_depth) + ] + ) + + ")" + if self.conditions_depth > 1 + else "__condition_1", + target=self.visit(stmt.target), + then_expr=self.visit(stmt.value), + dtype=stmt.target.data_type.dtype.name, + np=self.numpy_prefix, + ) + ) + else: + stmt_sources = self.visit(stmt) + if isinstance(stmt_sources, list): + sources.extend(stmt_sources) + else: + sources.append(stmt_sources) + + self.conditions_depth -= 1 + # return "\n".join(sources) + return sources class NumPyGenerator(gt_backend.BaseGenerator): @@ -206,6 +270,27 @@ def __init__(self, backend_class, options): interval_k_end_name="interval_k_end", ) + def generate_module_members(self): + source = """ +def VectorizedTernaryOp(*, condition, then_expr, else_expr, dtype): + return np.choose( + condition, + [else_expr, then_expr], + out=np.empty( + np.max( + ( + np.asanyarray(condition).shape, + np.asanyarray(then_expr).shape, + np.asanyarray(else_expr).shape, + ), + axis=0, + ), + dtype=dtype, + ), + ) +""" + return source + def generate_implementation(self): sources = gt_text.TextBlock(indent_size=self.TEMPLATE_INDENT_SIZE) self.source_generator(self.performance_ir, sources) diff --git a/tests/test_integration/test_suites.py b/tests/test_integration/test_suites.py index 1ac1873d0a..817e402f93 100644 --- a/tests/test_integration/test_suites.py +++ b/tests/test_integration/test_suites.py @@ -358,3 +358,84 @@ def validation(u, diffusion, *, weight, domain, origin, **kwargs): diffusion[...] = u[2:-2, 2:-2, :] - weight * ( flux_i[1:, :, :] - flux_i[:-1, :, :] + flux_j[:, 1:, :] - flux_j[:, :-1, :] ) + + +class TestRuntimeIfFlat(gt_testing.StencilTestSuite): + """Diffusion in a horizontal 2D plane . + """ + + dtypes = (np.float_,) + domain_range = [(1, 15), (1, 15), (1, 15)] + backends = ["debug", "numpy", "gtx86"] + symbols = dict( + outfield=gt_testing.field(in_range=(-10, 10), boundary=[(0, 0), (0, 0), (0, 0)]), + ) + + def definition(outfield): + + with computation(PARALLEL), interval(...): + + if 1: + outfield = 1 + else: + outfield = 2 + + def validation(outfield, *, domain, origin, **kwargs): + outfield[...] = 1 + + +class TestRuntimeIfNested(gt_testing.StencilTestSuite): + """Diffusion in a horizontal 2D plane . + """ + + dtypes = (np.float_,) + domain_range = [(1, 15), (1, 15), (1, 15)] + backends = ["debug", "numpy", "gtx86"] + symbols = dict( + outfield=gt_testing.field(in_range=(-10, 10), boundary=[(0, 0), (0, 0), (0, 0)]), + ) + + def definition(outfield): + + with computation(PARALLEL), interval(...): + + if 1: + if 0: + outfield = 1 + else: + outfield = 2 + else: + outfield = 3 + + def validation(outfield, *, domain, origin, **kwargs): + outfield[...] = 2 + + +class TestRuntimeIfNestedDataDependent(gt_testing.StencilTestSuite): + """Diffusion in a horizontal 2D plane . + """ + + dtypes = (np.float_,) + domain_range = [(1, 15), (1, 15), (1, 15)] + backends = ["debug", "numpy", "gtx86"] + symbols = dict( + infield=gt_testing.field(in_range=(-10, 10), boundary=[(0, 0), (0, 0), (0, 0)]), + outfield=gt_testing.field(in_range=(-10, 10), boundary=[(0, 0), (0, 0), (0, 0)]), + ) + + def definition(infield, outfield): + + with computation(PARALLEL), interval(...): + + if infield > 0: + if infield < 0.5: + outfield = 1 + else: + outfield = 2 + else: + outfield = 3 + + def validation(infield, outfield, *, domain, origin, **kwargs): + outfield[...] = 3 + outfield[...] = np.choose(infield > 0, [outfield, 2]) + outfield[...] = np.choose(np.logical_and(infield > 0, infield < 0.5), [outfield, 1]) From aa91224270faed86c55fa048b80ed7167f4cf7e1 Mon Sep 17 00:00:00 2001 From: Linus Groner Date: Fri, 6 Dec 2019 17:18:00 +0100 Subject: [PATCH 2/3] Renaming and Code reuse for ternary --- src/gt4py/backend/numpy_backend.py | 54 ++++++++++++--------------- tests/test_integration/test_suites.py | 26 +++++++++++-- 2 files changed, 46 insertions(+), 34 deletions(-) diff --git a/src/gt4py/backend/numpy_backend.py b/src/gt4py/backend/numpy_backend.py index 229ba12cf8..cda75da25b 100644 --- a/src/gt4py/backend/numpy_backend.py +++ b/src/gt4py/backend/numpy_backend.py @@ -176,7 +176,7 @@ def visit_TernaryOpExpr(self, node: gt_ir.TernaryOpExpr): then_fmt = "({})" if isinstance(node.then_expr, gt_ir.CompositeExpr) else "{}" else_fmt = "({})" if isinstance(node.else_expr, gt_ir.CompositeExpr) else "{}" - source = "VectorizedTernaryOp(condition={condition}, then_expr={then_expr}, else_expr={else_expr}, dtype={np}.{dtype})".format( + source = "vectorized_ternary_op(condition={condition}, then_expr={then_expr}, else_expr={else_expr}, dtype={np}.{dtype})".format( condition=self.visit(node.condition), then_expr=then_fmt.format(self.visit(node.then_expr)), else_expr=else_fmt.format(self.visit(node.else_expr)), @@ -195,11 +195,17 @@ def visit_If(self, node: gt_ir.If): ) ) - for stmt in node.main_body.stmts: + stmts = [ + *[(True, stmt) for stmt in node.main_body.stmts], + *[(False, stmt) for stmt in node.else_body.stmts], + ] + + for is_if, stmt in stmts: + if isinstance(stmt, gt_ir.Assign): - sources.append( - "{target} = VectorizedTernaryOp(condition={condition}, then_expr={then_expr}, else_expr={target}, dtype={np}.{dtype})".format( - condition="{np}.logical_and(".format(np=self.numpy_prefix) + condition = ( + ( + "{np}.logical_and(".format(np=self.numpy_prefix) + ", ".join( [ "__condition_{level}".format(level=i + 1) @@ -207,35 +213,21 @@ def visit_If(self, node: gt_ir.If): ] ) + ")" - if self.conditions_depth > 1 - else "__condition_1", - target=self.visit(stmt.target), - then_expr=self.visit(stmt.value), - dtype=stmt.target.data_type.dtype.name, - np=self.numpy_prefix, ) + if self.conditions_depth > 1 + else "__condition_1" ) - else: - stmt_sources = self.visit(stmt) - if isinstance(stmt_sources, list): - sources.extend(stmt_sources) - else: - sources.append(stmt_sources) - for stmt in node.else_body.stmts: - if isinstance(stmt, gt_ir.Assign): + condition = ( + condition + if is_if + else "{np}.logical_not({condition})".format( + np=self.numpy_prefix, condition=condition + ) + ) sources.append( - "{target} = VectorizedTernaryOp(condition={np}.logical_not({condition}), then_expr={then_expr}, else_expr={target}, dtype={np}.{dtype})".format( - condition="{np}.logical_and(".format(np=self.numpy_prefix) - + ", ".join( - [ - "__condition_{level}".format(level=i + 1) - for i in range(self.conditions_depth) - ] - ) - + ")" - if self.conditions_depth > 1 - else "__condition_1", + "{target} = vectorized_ternary_op(condition={condition}, then_expr={then_expr}, else_expr={target}, dtype={np}.{dtype})".format( + condition=condition, target=self.visit(stmt.target), then_expr=self.visit(stmt.value), dtype=stmt.target.data_type.dtype.name, @@ -272,7 +264,7 @@ def __init__(self, backend_class, options): def generate_module_members(self): source = """ -def VectorizedTernaryOp(*, condition, then_expr, else_expr, dtype): +def vectorized_ternary_op(*, condition, then_expr, else_expr, dtype): return np.choose( condition, [else_expr, then_expr], diff --git a/tests/test_integration/test_suites.py b/tests/test_integration/test_suites.py index 817e402f93..ad1ba56440 100644 --- a/tests/test_integration/test_suites.py +++ b/tests/test_integration/test_suites.py @@ -361,7 +361,7 @@ def validation(u, diffusion, *, weight, domain, origin, **kwargs): class TestRuntimeIfFlat(gt_testing.StencilTestSuite): - """Diffusion in a horizontal 2D plane . + """Tests runtime ifs. """ dtypes = (np.float_,) @@ -385,7 +385,7 @@ def validation(outfield, *, domain, origin, **kwargs): class TestRuntimeIfNested(gt_testing.StencilTestSuite): - """Diffusion in a horizontal 2D plane . + """Tests nested runtime ifs """ dtypes = (np.float_,) @@ -412,7 +412,7 @@ def validation(outfield, *, domain, origin, **kwargs): class TestRuntimeIfNestedDataDependent(gt_testing.StencilTestSuite): - """Diffusion in a horizontal 2D plane . + """Tests nested runtime ifs, where not the same branch is used for all data points. """ dtypes = (np.float_,) @@ -439,3 +439,23 @@ def validation(infield, outfield, *, domain, origin, **kwargs): outfield[...] = 3 outfield[...] = np.choose(infield > 0, [outfield, 2]) outfield[...] = np.choose(np.logical_and(infield > 0, infield < 0.5), [outfield, 1]) + + +class TestTernaryOp(gt_testing.StencilTestSuite): + + + dtypes = (np.float_,) + domain_range = [(1, 15), (2, 15), (1, 15)] + backends = ["numpy", "gtx86", "gtmc"] + symbols = dict( + infield=gt_testing.field(in_range=(-10, 10), boundary=[(0, 0), (0, 1), (0, 0)]), + outfield=gt_testing.field(in_range=(-10, 10), boundary=[(0, 0), (0, 0), (0, 0)]), + ) + + def definition(infield, outfield): + + with computation(PARALLEL), interval(...): + outfield = (infield > 0.0) * infield + (infield <= 0.0) * (-infield[0, 1, 0]) + + def validation(infield, outfield, *, domain, origin, **kwargs): + outfield[...] = np.choose(infield[:, :-1, :] > 0, [-infield[:, 1:, :], infield[:, :-1, :]]) From 34abe0c1d6cace899c1d3e868774cb709b9a6144 Mon Sep 17 00:00:00 2001 From: Linus Groner Date: Mon, 9 Dec 2019 11:44:58 +0100 Subject: [PATCH 3/3] numpy runtime ifs: else case implemented by switching branches in ternary --- src/gt4py/backend/numpy_backend.py | 16 ++++++---------- src/gt4py/backend/python_generator.py | 1 - tests/test_integration/test_suites.py | 1 - 3 files changed, 6 insertions(+), 12 deletions(-) diff --git a/src/gt4py/backend/numpy_backend.py b/src/gt4py/backend/numpy_backend.py index cda75da25b..27aa0fd327 100644 --- a/src/gt4py/backend/numpy_backend.py +++ b/src/gt4py/backend/numpy_backend.py @@ -218,18 +218,14 @@ def visit_If(self, node: gt_ir.If): else "__condition_1" ) - condition = ( - condition - if is_if - else "{np}.logical_not({condition})".format( - np=self.numpy_prefix, condition=condition - ) - ) + target = self.visit(stmt.target) + value = self.visit(stmt.value) sources.append( - "{target} = vectorized_ternary_op(condition={condition}, then_expr={then_expr}, else_expr={target}, dtype={np}.{dtype})".format( + "{target} = vectorized_ternary_op(condition={condition}, then_expr={then_expr}, else_expr={else_expr}, dtype={np}.{dtype})".format( condition=condition, - target=self.visit(stmt.target), - then_expr=self.visit(stmt.value), + target=target, + then_expr=value if is_if else target, + else_expr=target if is_if else value, dtype=stmt.target.data_type.dtype.name, np=self.numpy_prefix, ) diff --git a/src/gt4py/backend/python_generator.py b/src/gt4py/backend/python_generator.py index 4c442a8b5a..1adf97f6b9 100644 --- a/src/gt4py/backend/python_generator.py +++ b/src/gt4py/backend/python_generator.py @@ -280,4 +280,3 @@ def visit_StencilImplementation(self, node: gt_ir.StencilImplementation): for stage in group.stages: self.visit(stage, iteration_order=multi_stage.iteration_order) self.sources.append("") - diff --git a/tests/test_integration/test_suites.py b/tests/test_integration/test_suites.py index ad1ba56440..41763f5940 100644 --- a/tests/test_integration/test_suites.py +++ b/tests/test_integration/test_suites.py @@ -443,7 +443,6 @@ def validation(infield, outfield, *, domain, origin, **kwargs): class TestTernaryOp(gt_testing.StencilTestSuite): - dtypes = (np.float_,) domain_range = [(1, 15), (2, 15), (1, 15)] backends = ["numpy", "gtx86", "gtmc"]