Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 76 additions & 3 deletions src/gt4py/backend/numpy_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def __init__(self, *args, interval_k_start_name, interval_k_end_name, **kwargs):
super().__init__(*args, **kwargs)
self.interval_k_start_name = interval_k_start_name
self.interval_k_end_name = interval_k_end_name
self.conditions_depth = 0

def _make_field_origin(self, name: str, origin=None):
if origin is None:
Expand Down Expand Up @@ -174,9 +175,8 @@ def visit_StencilImplementation(self, node: gt_ir.StencilImplementation):
def visit_TernaryOpExpr(self, node: gt_ir.TernaryOpExpr):
then_fmt = "({})" if isinstance(node.then_expr, gt_ir.CompositeExpr) else "{}"
else_fmt = "({})" if isinstance(node.else_expr, gt_ir.CompositeExpr) else "{}"
# source = "np.vectorize(lambda cond, then_expr, else_expr:then_expr if cond else else_expr)({condition}, {then_expr}, {else_expr})".format(

source = "{np}.choose({condition},[{else_expr}, {then_expr}], out={np}.empty({np}.max(({np}.asanyarray({condition}).shape,{np}.asanyarray({then_expr}).shape,{np}.asanyarray({else_expr}).shape),axis=0), dtype={np}.{dtype}))".format(
source = "vectorized_ternary_op(condition={condition}, then_expr={then_expr}, else_expr={else_expr}, dtype={np}.{dtype})".format(
condition=self.visit(node.condition),
then_expr=then_fmt.format(self.visit(node.then_expr)),
else_expr=else_fmt.format(self.visit(node.else_expr)),
Expand All @@ -187,7 +187,59 @@ def visit_TernaryOpExpr(self, node: gt_ir.TernaryOpExpr):
return source

def visit_If(self, node: gt_ir.If):
raise NotImplementedError("The numpy backend does not support runtime if statements.")
sources = []
self.conditions_depth += 1
sources.append(
"__condition_{level} = {condition}".format(
level=self.conditions_depth, condition=self.visit(node.condition)
)
)

stmts = [
*[(True, stmt) for stmt in node.main_body.stmts],
*[(False, stmt) for stmt in node.else_body.stmts],
]

for is_if, stmt in stmts:

if isinstance(stmt, gt_ir.Assign):
condition = (
(
"{np}.logical_and(".format(np=self.numpy_prefix)
+ ", ".join(
[
"__condition_{level}".format(level=i + 1)
for i in range(self.conditions_depth)
]
)
+ ")"
)
if self.conditions_depth > 1
else "__condition_1"
)

target = self.visit(stmt.target)
value = self.visit(stmt.value)
sources.append(
"{target} = vectorized_ternary_op(condition={condition}, then_expr={then_expr}, else_expr={else_expr}, dtype={np}.{dtype})".format(
condition=condition,
target=target,
then_expr=value if is_if else target,
else_expr=target if is_if else value,
dtype=stmt.target.data_type.dtype.name,
np=self.numpy_prefix,
)
)
else:
stmt_sources = self.visit(stmt)
if isinstance(stmt_sources, list):
sources.extend(stmt_sources)
else:
sources.append(stmt_sources)

self.conditions_depth -= 1
# return "\n".join(sources)
return sources


class NumPyGenerator(gt_backend.BaseGenerator):
Expand All @@ -206,6 +258,27 @@ def __init__(self, backend_class, options):
interval_k_end_name="interval_k_end",
)

def generate_module_members(self):
source = """
def vectorized_ternary_op(*, condition, then_expr, else_expr, dtype):
return np.choose(
condition,
[else_expr, then_expr],
out=np.empty(
np.max(
(
np.asanyarray(condition).shape,
np.asanyarray(then_expr).shape,
np.asanyarray(else_expr).shape,
),
axis=0,
),
dtype=dtype,
),
)
"""
return source

def generate_implementation(self):
sources = gt_text.TextBlock(indent_size=self.TEMPLATE_INDENT_SIZE)
self.source_generator(self.performance_ir, sources)
Expand Down
1 change: 0 additions & 1 deletion src/gt4py/backend/python_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,4 +280,3 @@ def visit_StencilImplementation(self, node: gt_ir.StencilImplementation):
for stage in group.stages:
self.visit(stage, iteration_order=multi_stage.iteration_order)
self.sources.append("")

100 changes: 100 additions & 0 deletions tests/test_integration/test_suites.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,3 +358,103 @@ def validation(u, diffusion, *, weight, domain, origin, **kwargs):
diffusion[...] = u[2:-2, 2:-2, :] - weight * (
flux_i[1:, :, :] - flux_i[:-1, :, :] + flux_j[:, 1:, :] - flux_j[:, :-1, :]
)


class TestRuntimeIfFlat(gt_testing.StencilTestSuite):
"""Tests runtime ifs.
"""

dtypes = (np.float_,)
domain_range = [(1, 15), (1, 15), (1, 15)]
backends = ["debug", "numpy", "gtx86"]
symbols = dict(
outfield=gt_testing.field(in_range=(-10, 10), boundary=[(0, 0), (0, 0), (0, 0)]),
)

def definition(outfield):

with computation(PARALLEL), interval(...):

if 1:
outfield = 1
else:
outfield = 2

def validation(outfield, *, domain, origin, **kwargs):
outfield[...] = 1


class TestRuntimeIfNested(gt_testing.StencilTestSuite):
"""Tests nested runtime ifs
"""

dtypes = (np.float_,)
domain_range = [(1, 15), (1, 15), (1, 15)]
backends = ["debug", "numpy", "gtx86"]
symbols = dict(
outfield=gt_testing.field(in_range=(-10, 10), boundary=[(0, 0), (0, 0), (0, 0)]),
)

def definition(outfield):

with computation(PARALLEL), interval(...):

if 1:
if 0:
outfield = 1
else:
outfield = 2
else:
outfield = 3

def validation(outfield, *, domain, origin, **kwargs):
outfield[...] = 2


class TestRuntimeIfNestedDataDependent(gt_testing.StencilTestSuite):
"""Tests nested runtime ifs, where not the same branch is used for all data points.
"""

dtypes = (np.float_,)
domain_range = [(1, 15), (1, 15), (1, 15)]
backends = ["debug", "numpy", "gtx86"]
symbols = dict(
infield=gt_testing.field(in_range=(-10, 10), boundary=[(0, 0), (0, 0), (0, 0)]),
outfield=gt_testing.field(in_range=(-10, 10), boundary=[(0, 0), (0, 0), (0, 0)]),
)

def definition(infield, outfield):

with computation(PARALLEL), interval(...):

if infield > 0:
if infield < 0.5:
outfield = 1
else:
outfield = 2
else:
outfield = 3

def validation(infield, outfield, *, domain, origin, **kwargs):
outfield[...] = 3
outfield[...] = np.choose(infield > 0, [outfield, 2])
outfield[...] = np.choose(np.logical_and(infield > 0, infield < 0.5), [outfield, 1])


class TestTernaryOp(gt_testing.StencilTestSuite):

dtypes = (np.float_,)
domain_range = [(1, 15), (2, 15), (1, 15)]
backends = ["numpy", "gtx86", "gtmc"]
symbols = dict(
infield=gt_testing.field(in_range=(-10, 10), boundary=[(0, 0), (0, 1), (0, 0)]),
outfield=gt_testing.field(in_range=(-10, 10), boundary=[(0, 0), (0, 0), (0, 0)]),
)

def definition(infield, outfield):

with computation(PARALLEL), interval(...):
outfield = (infield > 0.0) * infield + (infield <= 0.0) * (-infield[0, 1, 0])

def validation(infield, outfield, *, domain, origin, **kwargs):
outfield[...] = np.choose(infield[:, :-1, :] > 0, [-infield[:, 1:, :], infield[:, :-1, :]])