Skip to content
81 changes: 81 additions & 0 deletions src/gt4py/next/iterator/transforms/inline_literal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2024, ETH Zurich
# All rights reserved.
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

from typing import Mapping

from gt4py import eve
from gt4py.next.iterator import ir
from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im


class ReplaceLiterals(eve.PreserveLocationVisitor, eve.NodeTranslator):
PRESERVED_ANNEX_ATTRS = ("type", "domain")

def visit_FunCall(self, node: ir.FunCall, *, symbol_map: Mapping[str, ir.Literal]):
if cpm.is_call_to(node, "deref"):
assert len(node.args) == 1
if (
isinstance(node.args[0], ir.SymRef)
and (symbol_name := str(node.args[0].id)) in symbol_map
):
return symbol_map[symbol_name]

return self.generic_visit(node, symbol_map=symbol_map)

def visit_SymRef(self, node: ir.SymRef, *, symbol_map: Mapping[str, ir.Literal]):
return symbol_map.get(str(node.id), node)


class InlineLiteral(eve.NodeTranslator):
"""Inline literal arguments (constants) of field operators into the lambda expression."""

PRESERVED_ANNEX_ATTRS = ("domain", "type")

def visit_FunCall(self, node: ir.FunCall) -> ir.Node:
node = self.generic_visit(node)

if cpm.is_applied_as_fieldop(node):
assert len(node.fun.args) in {1, 2}

lambda_params = []
if isinstance(node.fun.args[0], ir.Lambda):
lambda_node = node.fun.args[0]
elif cpm.is_call_to(node.fun.args[0], "scan"):
assert isinstance(node.fun.args[0].args[0], ir.Lambda)
lambda_node = node.fun.args[0].args[0]
lambda_params.append(lambda_node.params[0])
else:
return node

fun_args = []
symbol_map = {}
pstart = len(lambda_params)
for lambda_param, fun_arg in zip(lambda_node.params[pstart:], node.args, strict=True):
if isinstance(fun_arg, ir.Literal):
symbol_name = str(lambda_param.id)
symbol_map[symbol_name] = fun_arg
else:
fun_args.append(fun_arg)
lambda_params.append(lambda_param)

if symbol_map:
domain = node.fun.args[1] if len(node.fun.args) == 2 else None
lambda_expr = ReplaceLiterals().visit(lambda_node.expr, symbol_map=symbol_map)
lambda_node = im.lambda_(*lambda_params)(lambda_expr)
if isinstance(node.fun.args[0], ir.Lambda):
return im.as_fieldop(lambda_node, domain)(*fun_args)
else:
scan_expr = im.scan(
lambda_node, node.fun.args[0].args[1], node.fun.args[0].args[2]
)
return im.as_fieldop(scan_expr, domain)(*fun_args)
return node

@classmethod
def apply(cls, node: ir.Node) -> ir.Node:
return cls().visit(node)
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@
domain_utils,
ir_makers as im,
)
from gt4py.next.iterator.transforms import prune_casts as ir_prune_casts, symbol_ref_utils
from gt4py.next.iterator.transforms import (
inline_literal,
prune_casts as ir_prune_casts,
symbol_ref_utils,
)
from gt4py.next.iterator.type_system import inference as gtir_type_inference
from gt4py.next.program_processors.runners.dace import sdfg_args as gtx_dace_args
from gt4py.next.program_processors.runners.dace.lowering import (
Expand Down Expand Up @@ -1338,6 +1342,7 @@ def build_sdfg_from_gtir(
if ir.declarations:
raise NotImplementedError("Temporaries not supported yet by GTIR DaCe backend.")

ir = inline_literal.InlineLiteral().visit(ir)
ir = gtir_type_inference.infer(ir, offset_provider_type=offset_provider_type)
ir = ir_prune_casts.PruneCasts().visit(ir)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -685,6 +685,10 @@ def translate_scan(

# for output connections, we create temporary arrays that contain the computation
# results of a column slice for each point in the horizontal domain
if isinstance(lambda_output, tuple) and not isinstance(node.annex.domain, tuple):
domain_tree = gtx_utils.tree_map(lambda x: node.annex.domain)(lambda_output)
else:
domain_tree = node.annex.domain
output_tree = gtx_utils.tree_map(
lambda output_data, output_domain: _handle_dataflow_result_of_nested_sdfg(
sdfg_builder=sdfg_builder,
Expand All @@ -694,9 +698,9 @@ def translate_scan(
inner_data=output_data,
field_domain=output_domain,
)
)(lambda_output, node.annex.domain)
)(lambda_output, domain_tree)

# we call a helper method to create a map scope that will compute the entire field
return _create_scan_field_operator(
ctx, field_domain, node.type, sdfg_builder, input_edges, output_tree, node.annex.domain
ctx, field_domain, node.type, sdfg_builder, input_edges, output_tree, domain_tree
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2024, ETH Zurich
# All rights reserved.
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

from gt4py import next as gtx
from gt4py.next.iterator.ir_utils import ir_makers as im
from gt4py.next.iterator.transforms import inline_literal
from gt4py.next.type_system import type_specifications as ts


def test_inline_literal_fieldop():
IDim = gtx.Dimension("IDim")
x_ref = im.ref("x", ts.FieldType(dims=[IDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT32)))
testee = im.op_as_fieldop("plus")(x_ref, 1.0)
expected = im.as_fieldop(im.lambda_("__arg0")(im.plus(im.deref("__arg0"), 1.0)))(x_ref)
actual = inline_literal.InlineLiteral.apply(testee)
assert actual == expected


def test_inline_literal_scan():
IDim = gtx.Dimension("IDim")
x_ref = im.ref("x", ts.FieldType(dims=[IDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT32)))
testee = im.as_fieldop(
im.scan(
im.lambda_("state", "inp", "val")(
im.plus("state", im.multiplies_(im.deref("inp"), "val"))
),
True,
0.0,
)
)(x_ref, 2.0)
expected = im.as_fieldop(
im.scan(
im.lambda_("state", "inp")(im.plus("state", im.multiplies_(im.deref("inp"), 2.0))),
True,
0.0,
)
)(x_ref)
actual = inline_literal.InlineLiteral.apply(testee)
assert actual == expected