diff --git a/src/gt4py/next/iterator/transforms/inline_literal.py b/src/gt4py/next/iterator/transforms/inline_literal.py new file mode 100644 index 0000000000..1a30aeba7f --- /dev/null +++ b/src/gt4py/next/iterator/transforms/inline_literal.py @@ -0,0 +1,81 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from typing import Mapping + +from gt4py import eve +from gt4py.next.iterator import ir +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im + + +class ReplaceLiterals(eve.PreserveLocationVisitor, eve.NodeTranslator): + PRESERVED_ANNEX_ATTRS = ("type", "domain") + + def visit_FunCall(self, node: ir.FunCall, *, symbol_map: Mapping[str, ir.Literal]): + if cpm.is_call_to(node, "deref"): + assert len(node.args) == 1 + if ( + isinstance(node.args[0], ir.SymRef) + and (symbol_name := str(node.args[0].id)) in symbol_map + ): + return symbol_map[symbol_name] + + return self.generic_visit(node, symbol_map=symbol_map) + + def visit_SymRef(self, node: ir.SymRef, *, symbol_map: Mapping[str, ir.Literal]): + return symbol_map.get(str(node.id), node) + + +class InlineLiteral(eve.NodeTranslator): + """Inline literal arguments (constants) of field operators into the lambda expression.""" + + PRESERVED_ANNEX_ATTRS = ("domain", "type") + + def visit_FunCall(self, node: ir.FunCall) -> ir.Node: + node = self.generic_visit(node) + + if cpm.is_applied_as_fieldop(node): + assert len(node.fun.args) in {1, 2} + + lambda_params = [] + if isinstance(node.fun.args[0], ir.Lambda): + lambda_node = node.fun.args[0] + elif cpm.is_call_to(node.fun.args[0], "scan"): + assert isinstance(node.fun.args[0].args[0], ir.Lambda) + lambda_node = node.fun.args[0].args[0] + lambda_params.append(lambda_node.params[0]) + else: + return node + + fun_args = [] + symbol_map = {} + pstart = len(lambda_params) + for lambda_param, fun_arg in zip(lambda_node.params[pstart:], node.args, strict=True): + if isinstance(fun_arg, ir.Literal): + symbol_name = str(lambda_param.id) + symbol_map[symbol_name] = fun_arg + else: + fun_args.append(fun_arg) + lambda_params.append(lambda_param) + + if symbol_map: + domain = node.fun.args[1] if len(node.fun.args) == 2 else None + lambda_expr = ReplaceLiterals().visit(lambda_node.expr, symbol_map=symbol_map) + lambda_node = im.lambda_(*lambda_params)(lambda_expr) + if isinstance(node.fun.args[0], ir.Lambda): + return im.as_fieldop(lambda_node, domain)(*fun_args) + else: + scan_expr = im.scan( + lambda_node, node.fun.args[0].args[1], node.fun.args[0].args[2] + ) + return im.as_fieldop(scan_expr, domain)(*fun_args) + return node + + @classmethod + def apply(cls, node: ir.Node) -> ir.Node: + return cls().visit(node) diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py index 1d0e20cad9..52cc31018f 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py @@ -31,7 +31,11 @@ domain_utils, ir_makers as im, ) -from gt4py.next.iterator.transforms import prune_casts as ir_prune_casts, symbol_ref_utils +from gt4py.next.iterator.transforms import ( + inline_literal, + prune_casts as ir_prune_casts, + symbol_ref_utils, +) from gt4py.next.iterator.type_system import inference as gtir_type_inference from gt4py.next.program_processors.runners.dace import sdfg_args as gtx_dace_args from gt4py.next.program_processors.runners.dace.lowering import ( @@ -1338,6 +1342,7 @@ def build_sdfg_from_gtir( if ir.declarations: raise NotImplementedError("Temporaries not supported yet by GTIR DaCe backend.") + ir = inline_literal.InlineLiteral().visit(ir) ir = gtir_type_inference.infer(ir, offset_provider_type=offset_provider_type) ir = ir_prune_casts.PruneCasts().visit(ir) diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_scan.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_scan.py index b8f958f6ab..9fc17d47ac 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_scan.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_scan.py @@ -685,6 +685,10 @@ def translate_scan( # for output connections, we create temporary arrays that contain the computation # results of a column slice for each point in the horizontal domain + if isinstance(lambda_output, tuple) and not isinstance(node.annex.domain, tuple): + domain_tree = gtx_utils.tree_map(lambda x: node.annex.domain)(lambda_output) + else: + domain_tree = node.annex.domain output_tree = gtx_utils.tree_map( lambda output_data, output_domain: _handle_dataflow_result_of_nested_sdfg( sdfg_builder=sdfg_builder, @@ -694,9 +698,9 @@ def translate_scan( inner_data=output_data, field_domain=output_domain, ) - )(lambda_output, node.annex.domain) + )(lambda_output, domain_tree) # we call a helper method to create a map scope that will compute the entire field return _create_scan_field_operator( - ctx, field_domain, node.type, sdfg_builder, input_edges, output_tree, node.annex.domain + ctx, field_domain, node.type, sdfg_builder, input_edges, output_tree, domain_tree ) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_literal.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_literal.py new file mode 100644 index 0000000000..b0da1598e4 --- /dev/null +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_literal.py @@ -0,0 +1,44 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from gt4py import next as gtx +from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.transforms import inline_literal +from gt4py.next.type_system import type_specifications as ts + + +def test_inline_literal_fieldop(): + IDim = gtx.Dimension("IDim") + x_ref = im.ref("x", ts.FieldType(dims=[IDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT32))) + testee = im.op_as_fieldop("plus")(x_ref, 1.0) + expected = im.as_fieldop(im.lambda_("__arg0")(im.plus(im.deref("__arg0"), 1.0)))(x_ref) + actual = inline_literal.InlineLiteral.apply(testee) + assert actual == expected + + +def test_inline_literal_scan(): + IDim = gtx.Dimension("IDim") + x_ref = im.ref("x", ts.FieldType(dims=[IDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT32))) + testee = im.as_fieldop( + im.scan( + im.lambda_("state", "inp", "val")( + im.plus("state", im.multiplies_(im.deref("inp"), "val")) + ), + True, + 0.0, + ) + )(x_ref, 2.0) + expected = im.as_fieldop( + im.scan( + im.lambda_("state", "inp")(im.plus("state", im.multiplies_(im.deref("inp"), 2.0))), + True, + 0.0, + ) + )(x_ref) + actual = inline_literal.InlineLiteral.apply(testee) + assert actual == expected